diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 218b30bd9e33..fb8182b21dda 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -538,6 +538,7 @@ def _test_quantizer( expected_node_occurrence, expected_node_list=None, is_qat=False, + debug=False, ): m_eager = model.train() if is_qat else model.eval() @@ -556,6 +557,8 @@ def _test_quantizer( prepare_model = copy.deepcopy(m) m = convert_pt2e(m) convert_model = copy.deepcopy(m) + if debug: + convert_model.print_readable(True) pt2_quant_output = m(*example_inputs) node_occurrence = { ns.call_function(k): v for k, v in expected_node_occurrence.items() @@ -751,9 +754,11 @@ def test_conv2d_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] self._test_quantizer( m, @@ -1346,9 +1351,11 @@ def test_linear_binary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1401,9 +1408,11 @@ def test_linear_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1472,9 +1481,11 @@ def test_linear_binary_unary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1694,9 +1705,11 @@ def test_qat_conv2d_binary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] @@ -1741,9 +1754,11 @@ def test_qat_conv2d_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] self._test_quantizer( m, @@ -1865,6 +1880,410 @@ def test_qat_dynamic_quant_linear(self): is_qat=True, ) + @skipIfNoX86 + def test_set_module_name_qconfig(self): + """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`. 
+ + Expect that all linear layers within the submodule `sub` are quantized. + """ + + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.relu1 = torch.nn.ReLU(inplace=False) + self.linear2 = torch.nn.Linear(10, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.relu1(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to `None` and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of two linear layers from `sub` + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # two Q/DQ pairs for two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_set_module_name_qconfig_with_underscores(self) -> None: + """Test that if a module name has an underscore, we can still quantize it.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + # This module name has underscores, which can be part of a mangled name. + self.foo_bar = torch.nn.Linear(2, 2) + self.baz = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.baz(self.foo_bar(x)) + + # Set global to no quantization and then default config for a specific submodule whose name includes an underscore. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "foo_bar", xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = capture_pre_autograd_graph(m, example_inputs) + m = prepare_pt2e(m, quantizer) + # Use a linear count instead of names because the names might change, but + # the order should be the same. + count = 0 + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear.default: + # Get the weight observer to see the per-channel vs per-tensor. + weight_observer_node = n.args[1] + if count == 0: + # for foo_bar. + self.assertEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + observer_instance = getattr(m, weight_observer_node.target) + self.assertEqual( + observer_instance.qscheme, torch.per_channel_symmetric + ) + else: + # For baz it should have no observer at all. 
+                    self.assertNotEqual(
+                        weight_observer_node.op,
+                        "call_module",
+                        f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead of call_module",
+                    )
+                count += 1
+
+    @skipIfNoX86
+    def test_set_module_name_and_module_type_case1(self):
+        """Test setting `module_name_qconfig` and `module_type_qconfig` at the same time.
+
+        Expect that all linear layers are not quantized except the last one.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(5, 10)
+                self.linear2 = torch.nn.Linear(10, 5)
+                self.sub = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        # Set `sub` with default config and then `None` for all `Linear`.
+        # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`.
+        quantizer = X86InductorQuantizer()
+        quantizer.set_module_name_qconfig(
+            "sub", xiq.get_default_x86_inductor_quantization_config()
+        ).set_module_type_qconfig(torch.nn.Linear, None)
+
+        node_occurrence = {
+            torch.ops.aten.linear.default: 3,
+            # quantize and dequantize the input of the last linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+            # dequantize the weight of the last linear
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            # first and second linear are not quantized
+            torch.ops.aten.linear.default,
+            torch.ops.aten.linear.default,
+            # last linear is quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+        ]
+        self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+    @skipIfNoX86
+    def test_set_module_name_and_module_type_case2(self):
+        """Test setting `module_name_qconfig` and `module_type_qconfig` at the same time.
+
+        Expect that all linear layers are quantized except the last one.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(5, 10)
+                self.linear2 = torch.nn.Linear(10, 5)
+                self.sub = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        # Set `sub` with None and then default config for all `Linear`.
+        quantizer = X86InductorQuantizer()
+        quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig(
+            torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config()
+        )
+
+        node_occurrence = {
+            torch.ops.aten.linear.default: 3,
+            # quantize and dequantize the input and output of the first and second linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            # dequantize the weight of the first and second linear
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+        }
+        node_list = [
+            # Q/DQ for first linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+            # Q/DQ for second linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+            # last linear is not quantized
+            torch.ops.aten.linear.default,
+        ]
+        self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+    @skipIfNoX86
+    def test_set_module_name_qconfig_for_dynamic_quant(self):
+        """Test quantizing a specific submodule for dynamic quantization."""
+
+        with override_quantized_engine("x86"), torch.no_grad():
+            for is_qat in [False, True]:
+                m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
+                example_inputs = (torch.randn(1, 4, 64),)
+                # only quantize `q_proj` and `v_proj`
+                dynamic_config = xiq.get_default_x86_inductor_quantization_config(
+                    is_dynamic=True, is_qat=is_qat
+                )
+                quantizer = (
+                    X86InductorQuantizer()
+                    .set_module_name_qconfig("q_proj", dynamic_config)
+                    .set_module_name_qconfig("v_proj", dynamic_config)
+                )
+                node_occurrence = {
+                    # quantize and dequantize the input
+                    torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+                    # dequantize the weight of q_proj and v_proj
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+                }
+                node_list = [
+                    # quantize and dequantize the input
+                    torch.ops.quantized_decomposed.choose_qparams.tensor,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+                    # q_proj
+                    torch.ops.aten.linear.default,
+                    # k_proj
+                    torch.ops.aten.linear.default,
+                    # v_proj
+                    torch.ops.aten.linear.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                    is_qat=is_qat,
+                )
+
+    @skipIfNoX86
+    def test_set_module_name_with_mixed_configs(self):
+        """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations.
+
+        The config for 'v_proj' will always be ignored and a warning will be raised.
+ """ + with override_quantized_engine("x86"), torch.no_grad(): + with self.assertWarns(UserWarning) as context: + for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product( + [False, True], repeat=4 + ): + if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat: + continue + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = ( + X86InductorQuantizer() + .set_module_name_qconfig( + "q_proj", + xiq.get_default_x86_inductor_quantization_config( + is_qat=q_is_qat, is_dynamic=q_is_dynamic + ), + ) + .set_module_name_qconfig( + "v_proj", + xiq.get_default_x86_inductor_quantization_config( + is_qat=v_is_qat, is_dynamic=v_is_dynamic + ), + ) + ) + quant_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequant_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + node_occurrence = { + # quantize and dequantize the input + quant_op: 1, + dequant_op: 1, + # only `q_proj` was quantized, dequantize its weight + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # quantize and dequantize the input + quant_op, + dequant_op, + # q_proj + torch.ops.aten.linear.default, + # k_proj/v_proj + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=q_is_qat, + ) + warning_msg = ( + "Mixed QAT and Non-QAT" + if q_is_qat != v_is_qat + else "Mixed dynamic and static" + ) + self.assertTrue( + any( + warning_msg in msg + for msg in [str(w.message) for w in context.warnings] + ) + ) + + @skipIfNoX86 + def test_set_module_name_and_module_type_with_mixed_configs(self): + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. + + Expect that only the last linear(`sub`) is quantized using static quantization. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) + self.sub = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set `sub` with static config and then dynamic config for a all `Linear`(ignored). 
+ quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config(is_dynamic=False) + ).set_module_type_qconfig( + torch.nn.Linear, + xiq.get_default_x86_inductor_quantization_config(is_dynamic=True), + ) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the last linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first and second linear are not quantized + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + # Q/DQ pairs for the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + @skipIfNoX86 def test_filter_conv2d_recipe(self): """ @@ -1994,12 +2413,12 @@ def test_attention_block(self): ) node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 5 - if annotate_matmul - else 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7 - if annotate_matmul - else 3, + torch.ops.quantized_decomposed.quantize_per_tensor.default: ( + 5 if annotate_matmul else 1 + ), + torch.ops.quantized_decomposed.dequantize_per_tensor.default: ( + 7 if annotate_matmul else 3 + ), # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index f948dbb112dc..68c90f5cf57f 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -48,3 +48,37 @@ def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]): ((user not in partition_nodes) or _is_sym_size_node(user)) for user in node.users ) + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. 
+ if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 89e4966bf4eb..6eecabb6fee0 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -15,8 +15,11 @@ Set, Tuple, TYPE_CHECKING, + Union, ) +from typing_extensions import TypeAlias + import torch import torch.nn.functional as F from torch.ao.quantization.fake_quantize import ( @@ -37,8 +40,9 @@ Quantizer, SharedQuantizationSpec, ) + +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( - _is_annotated, get_bias_qspec, get_input_act_qspec, get_output_act_qspec, @@ -53,6 +57,9 @@ SourcePartition, ) +FilterFn: TypeAlias = Callable[[List[Node]], bool] + + if TYPE_CHECKING: from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor @@ -68,6 +75,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): # * Node as output node of a fusion pattern. # * The fusion pattern supports int8 data type. # * The fusion pattern has inputs annotated to insert observer. + # * The quantization_config is not `None`. _is_output_of_quantized_pattern: bool = False @@ -102,6 +110,91 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): QUANT_ANNOTATION_KEY = "quantization_annotation" +def _skip_annotate(nodes: List[Node], filter_fn: Optional[FilterFn] = None) -> bool: + """Determine whether to skip annotation for a list of nodes.""" + + # 1) Skip annotate if any node is already annotated + if _is_any_annotated(nodes): + return True + + # 2) Proceed annotate if a) a filter function is provided + # and b) the given nodes list passes the filter function check. + if filter_fn and filter_fn(nodes): + return False + + return True + + +def _create_module_name_filter(module_name: str) -> FilterFn: + """Create a filter function for a given module name. + + The filter function takes a list of nodes (as determined by the annotate function) + and return True if *all* nodes come from the specified module name, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> module_name_filter = _create_module_name_filter_inner("sub") + >> print(module_name_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". + """ + + filter_fn = _get_module_name_filter(module_name) + + def check_all_nodes_from_module(nodes: List[Node]) -> bool: + all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) + return all_nodes_from_module_name + + return check_all_nodes_from_module + + +def _create_operator_type_filter( + operator_type: Callable, +) -> FilterFn: + """Create a filter function for a given operator type. + + The filter function takes a list of nodes and returns True if it contains + exactly one node with the specified operator type, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) 
# comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. + """ + + def operator_type_filter(nodes: List[Node]): + num_nodes_with_operator_type = sum( + node.target == operator_type for node in nodes + ) + if num_nodes_with_operator_type > 1: + raise NotImplementedError( + f"Several nodes within a single pattern are {operator_type}." + ) + return num_nodes_with_operator_type == 1 + + return operator_type_filter + + +def _global_config_filter(nodes: List[Node]) -> bool: + """Filter function for global configuration. + + This filter function takes a list of nodes and returns True if there is exactly one node + in the list that is a default quantizable operation, False otherwise. + """ + num_nodes_in_default_quantizable_ops = sum( + node.target in default_quantizable_ops for node in nodes + ) + if num_nodes_in_default_quantizable_ops > 1: + raise NotImplementedError( + "Several nodes within a single pattern are default quantizable operations." + ) + return num_nodes_in_default_quantizable_ops == 1 + + def _map_module_function_to_aten_operator_type(): module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {} map_list = ( @@ -294,16 +387,63 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_x86_inductor_config_and_operators() +def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None: + """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`).""" + if not isinstance(nodes, list): + nodes = [nodes] + for node in nodes: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + + +def _config_checker(method: Callable) -> Callable: + @functools.wraps(method) + def wrapper( + quantizer: "X86InductorQuantizer", + name: Any, + quantization_config: Optional["QuantizationConfig"], + ) -> "X86InductorQuantizer": + if quantizer._need_skip_config(quantization_config): + warnings.warn( + f"Skip the quantization config for {name}.", + ) + return quantizer + return method(quantizer, name, quantization_config) + + return wrapper + + +@dataclass +class _CurrentQuantizationMode: + r"""Configuration defining the current quantization mode for the quantizer. 
+
+    All possible current quantization modes are listed below:
+    ----------------------------------------------------------------------------------------------------------
+                |                               dynamic_state
+     qat_state  |---------------------------------------------------------------------------------------------
+                |                              None                            |     True      |     False
+    ----------------------------------------------------------------------------------------------------------
+        None    | quantizer does not receive a non-None `quantization_config`  |       \       |       \
+        False   | quantizer will not do QAT                                    |    dynamic    |    static
+        True    | quantizer will do QAT                                        | QAT + dynamic | QAT + static
+    """
+
+    qat_state: Optional[bool]
+    dynamic_state: Optional[bool]
+
+
 class X86InductorQuantizer(Quantizer):
     supported_config_and_operators = _get_supported_config_and_operators()
     module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type()
 
     def __init__(self):
         super().__init__()
-        self.global_config: QuantizationConfig = None  # type: ignore[assignment]
+        self.global_config: Optional[QuantizationConfig] = None
         self.operator_type_qconfig: Dict[
             torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
         ] = {}
+        self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {}
 
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
@@ -327,7 +467,78 @@ def get_supported_operator_for_quantization_config(
                 return ops
         return []
 
+    def _get_current_quantization_mode(self) -> _CurrentQuantizationMode:
+        """Retrieves the current quantization mode based on all configurations."""
+        qat_state = None
+        dynamic_state = None
+
+        # As we use `_need_skip_config` to skip all invalid configurations,
+        # we can safely assume that all existing non-None configurations
+        # have the same quantization mode.
+        for qconfig in (
+            list(self.module_name_qconfig.values())
+            + list(self.operator_type_qconfig.values())
+            + [self.global_config]
+        ):
+            if qconfig is not None:
+                # Query the `is_qat` state
+                if qat_state is None:
+                    qat_state = qconfig.is_qat
+                else:
+                    assert qat_state == qconfig.is_qat, (
+                        f"All non-None quantization configs should have the same `is_qat`,"
+                        f" but got {qat_state} and {qconfig.is_qat}."
+                    )
+                # Query the `is_dynamic` state
+                input_activation_spec = qconfig.input_activation
+                if input_activation_spec is not None:
+                    if dynamic_state is None:
+                        dynamic_state = input_activation_spec.is_dynamic
+                    else:
+                        assert dynamic_state == input_activation_spec.is_dynamic, (
+                            f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
+                            f" but got {dynamic_state} and {input_activation_spec.is_dynamic}."
+                        )
+        return _CurrentQuantizationMode(
+            qat_state=qat_state, dynamic_state=dynamic_state
+        )
+
+    def _need_skip_config(
+        self, quantization_config: Optional[QuantizationConfig]
+    ) -> bool:
+        """Check if the provided quantization config is valid for X86InductorQuantizer.
+
+        Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported.
+        To avoid such a mix, we compare the incoming configuration with the current configuration status.
+        Refer to the `_CurrentQuantizationMode` definition for all possible modes.
+ """ + if quantization_config is None: + return False + + need_skip = False + current_mode = self._get_current_quantization_mode() + if ( + current_mode.qat_state is not None + and current_mode.qat_state != quantization_config.is_qat + ): + warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.") + need_skip = True + if current_mode.dynamic_state is not None: + input_activation_spec = quantization_config.input_activation + if ( + input_activation_spec is not None + and current_mode.dynamic_state != input_activation_spec.is_dynamic + ): + warnings.warn( + "Mixed dynamic and static quantization config is not supported." + ) + need_skip = True + return need_skip + def set_global(self, quantization_config: QuantizationConfig): + if self._need_skip_config(quantization_config): + warnings.warn("Skip the global quantization config.") + return self self.global_config = quantization_config return self @@ -339,6 +550,7 @@ def get_global_quantization_config(self): ) return self.global_config + @_config_checker def set_function_type_qconfig( self, function_type: Callable, @@ -357,6 +569,7 @@ def set_function_type_qconfig( ) return self + @_config_checker def set_module_type_qconfig( self, module_type: torch.nn.Module, @@ -373,6 +586,19 @@ def set_module_type_qconfig( ) return self + @_config_checker + def set_module_name_qconfig( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + + The supported operators include `quantizable_ops` and `propagation_quantizable_ops`. 
+ """ + self.module_name_qconfig[module_name] = quantization_config + return self + def _set_aten_operator_qconfig( self, operator_type: torch._ops.OpOverloadPacket, @@ -386,22 +612,16 @@ def _set_aten_operator_qconfig( ) return self - def _get_aten_operator_qconfig( - self, - operator_type: torch._ops.OpOverloadPacket, - ) -> Optional[QuantizationConfig]: - if operator_type in self.operator_type_qconfig: - assert operator_type in quantizable_ops - return self.operator_type_qconfig[operator_type] - return self.global_config if operator_type in default_quantizable_ops else None - def _annotate_conv_node_helper( self, conv_node: torch.fx.Node, annotate_output: bool, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], ) -> None: """Helper function to annotate the conv node""" + if quantization_config is None: + _annotate_nodes_not_quantize(conv_node) + return input_qspec_map = {} input_node = conv_node.args[0] assert isinstance(input_node, Node) @@ -428,9 +648,12 @@ def _annotate_linear_node_helper( self, linear_node: torch.fx.Node, annotate_output: bool, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], ) -> None: """Helper function to annotate the linear node""" + if quantization_config is None: + _annotate_nodes_not_quantize(linear_node) + return input_qspec_map = {} assert linear_node.target in (torch.ops.aten.linear.default,) has_bias = len(linear_node.args) == 3 @@ -504,65 +727,92 @@ def _get_input_idx_for_binary_node( return conv_gemm_node_idx, extra_input_node_idx def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """just handling global spec for now""" - if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] - model = self._annotate_for_dynamic_quantization_config(model) - else: - model = self._annotate_for_static_quantization_config(model) + """Annotate the given model with quantization configurations. + + Annotation contracts: + 1. Annotate each node according to the user's qconfig in the following order: + `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. + 2. Avoid re-annotating nodes already annotated in prior stages. For example, + if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again + during the processing of the 'operator_type_qconfig' or 'global_config'. + 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. + + For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. + This filter function checks if the node is marked by current stage and not annotated by the previous stage. + """ + for module_name, quantization_config in self.module_name_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_module_name_filter(module_name) + ) + + for operator_type, quantization_config in self.operator_type_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_operator_type_filter(operator_type) + ) + + if self.global_config: + self._annotate_with_config( + model, + self.global_config, + _global_config_filter, + ) + + # Once we've annotated the model with quantization configurations, we also need to annotate + # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs, + # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op. 
+ # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) + return model - def _annotate_for_static_quantization_config( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: - r""" + def _annotate_with_config( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: FilterFn, + ) -> None: + """Annotate the model with the given quantization configuration. + High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model from start to the end. If a pattern supports computation with int8 data type and inputs connected to quantized patterns, annotate its inputs as quantized pattern. - Step 3: Since in step 2, we only annotate the inputs of quantized pattern. For some quantized patterns, - such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, - we need to annotate the output of this pattern. """ # Step1: Recipe of fusion patterns like conv/linear. - self._annotate_conv2d_fusion_pattern(model) - self._annotate_linear_fusion_pattern(model) - self._annotate_matmul(model) + self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_matmul(model, quantization_config, filter_fn) # Step2: Recipe to propagate annotation for patterns beside conv/linear. # Go through all the nodes from start to end. # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 - for node in model.graph.nodes: - self._annotate_propagation_quantizable_pattern(node) - - # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized - # in inputs. So, we can fuse dq-operator-q into a quantized op. 
- # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ - # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 - for node in model.graph.nodes: - self._annotate_output_for_int8_in_int8_out_pattern(node) - - return model - def _annotate_for_dynamic_quantization_config( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: - self._annotate_linear_fusion_pattern(model) - return model + self._annotate_propagation_quantizable_pattern_entry( + model, quantization_config, filter_fn + ) def _annotate_qat_conv2d_fusion_pattern( - self, model: torch.fx.GraphModule, config: QuantizationConfig + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ): # Annotate QAT Specific patterns - self._annotate_qat_conv2d_bn_binary_unary(model, config) - self._annotate_qat_conv2d_bn_binary(model, config) - self._annotate_qat_conv2d_bn_unary(model, config) - self._annotate_qat_conv2d_bn(model, config) + self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) def _annotate_qat_conv2d_bn_binary_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] @@ -602,25 +852,34 @@ def _annotate_qat_conv2d_bn_binary_unary( ): continue - if _is_annotated([unary_node, binary_node, bn_output_node, conv_node]): + if _skip_annotate( + [unary_node, binary_node, bn_output_node, conv_node], filter_fn + ): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - binary_node_input_qspec_map = {} - binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( - quantization_config - ) - binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - _annotated=True, - ) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize([binary_node, unary_node]) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(binary_partition.nodes)) @@ -628,7 +887,10 @@ def _annotate_qat_conv2d_bn_binary_unary( _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_qat_conv2d_bn_binary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] @@ -662,29 +924,37 @@ def _annotate_qat_conv2d_bn_binary( ): continue - if _is_annotated([binary_node, bn_output_node, conv_node]): + if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - binary_node_input_qspec_map = {} - binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( - quantization_config - ) - binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(binary_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(binary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_qat_conv2d_bn_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -716,23 +986,31 @@ def _annotate_qat_conv2d_bn_unary( ): continue - if _is_annotated([unary_node, bn_output_node, conv_node]): + if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(unary_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(unary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_qat_conv2d_bn( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] @@ -749,60 +1027,87 @@ def _annotate_qat_conv2d_bn( ): continue - if _is_annotated([bn_output_node, conv_node]): + if _skip_annotate([bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - bn_output_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + bn_output_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(bn_output_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) - def _annotate_conv2d_fusion_pattern(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.conv2d.default): - if config.is_qat: - # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat - self._annotate_qat_conv2d_fusion_pattern(model, config) - self._annotate_conv2d_binary_unary(model, config) - self._annotate_conv2d_binary(model, config) - self._annotate_conv2d_unary(model, config) - self._annotate_conv2d(model, config) - - def _annotate_linear_fusion_pattern(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default): - if config.input_activation and not config.input_activation.is_dynamic: - # Weiwen: Dynamic Quant of linear unary will be supported in next step - self._annotate_linear_binary_unary(model, config) - self._annotate_linear_unary(model, config) - self._annotate_linear(model, config) - - def _annotate_matmul(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.matmul.default): - for node in model.graph.nodes: - if node.target == torch.ops.aten.matmul.default and not _is_annotated( - [node] - ): - input_qspec_map = {} - matmul_node = node - for input_node in matmul_node.args: - input_qspec_map[input_node] = get_input_act_qspec(config) - matmul_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map=input_qspec_map, - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + def _annotate_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: 
Optional[FilterFn] = None, + ): + if (quantization_config is None) or (quantization_config.is_qat): + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_qat_conv2d_fusion_pattern( + model, quantization_config, filter_fn + ) + self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn) + self._annotate_conv2d_binary(model, quantization_config, filter_fn) + self._annotate_conv2d_unary(model, quantization_config, filter_fn) + self._annotate_conv2d(model, quantization_config, filter_fn) + + def _annotate_linear_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + if (quantization_config is None) or ( + quantization_config.input_activation + and not quantization_config.input_activation.is_dynamic + ): + # Weiwen: Dynamic Quant of linear unary will be supported in next step + self._annotate_linear_binary_unary(model, quantization_config, filter_fn) + self._annotate_linear_unary(model, quantization_config, filter_fn) + self._annotate_linear(model, quantization_config, filter_fn) + + def _annotate_matmul( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in model.graph.nodes: + if node.target != torch.ops.aten.matmul.default: + continue + if _skip_annotate([node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + continue + + input_qspec_map = {} + matmul_node = node + for input_node in matmul_node.args: + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) def _annotate_conv2d_binary_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # Conv2d + add + unary op fused_partitions = find_sequential_partitions( @@ -830,8 +1135,13 @@ def _annotate_conv2d_binary_unary( ): # No conv node found to be fused with add continue - if _is_annotated([unary_node, binary_node, conv_node]): + if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node, unary_node]) continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( @@ -847,7 +1157,10 @@ def _annotate_conv2d_binary_unary( ) def _annotate_conv2d_binary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # Conv2d + add fused_partitions = find_sequential_partitions( @@ -876,8 +1189,13 @@ def _annotate_conv2d_binary( ): # No conv node found to be fused with add continue - if _is_annotated([binary_node, conv_node]): + if _skip_annotate([binary_node, conv_node], filter_fn): continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node]) + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} 
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( @@ -890,7 +1208,10 @@ def _annotate_conv2d_binary( ) def _annotate_conv2d_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -916,8 +1237,13 @@ def _annotate_conv2d_unary( or conv_node.target != torch.ops.aten.conv2d.default ): continue - if _is_annotated([unary_node, conv_node]): + if _skip_annotate([unary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, unary_node]) continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, @@ -925,7 +1251,10 @@ def _annotate_conv2d_unary( ) def _annotate_conv2d( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: conv_partitions = get_source_partitions( gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] @@ -941,15 +1270,21 @@ def _annotate_conv2d( ): raise ValueError(f"{conv_node} is not an aten conv2d operator") # skip annotation if it is already annotated - if _is_annotated([conv_node]): + if _skip_annotate([conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, True, quantization_config) def _annotate_maxpool2d( - self, node: Node, quantization_config: QuantizationConfig + self, + node: Node, + quantization_config: Optional[QuantizationConfig], ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + maxpool_node = node if _is_any_annotated( [ @@ -957,6 +1292,7 @@ def _annotate_maxpool2d( ] ): return + input_node = maxpool_node.args[0] assert isinstance(input_node, Node) input_qspec_map = {} @@ -970,6 +1306,9 @@ def _annotate_maxpool2d( def _annotate_cat( self, node: Node, quantization_config: QuantizationConfig ) -> None: + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return cat_node = node input_nodes = cat_node.args[0] assert isinstance(input_nodes, Sequence) @@ -994,13 +1333,25 @@ def _annotate_cat( _is_output_of_quantized_pattern=True, ) - def _annotate_propagation_quantizable_pattern(self, node: Node) -> None: + def _annotate_propagation_quantizable_pattern_entry( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in gm.graph.nodes: + self._annotate_propagation_quantizable_pattern( + node, quantization_config, filter_fn + ) + + def _annotate_propagation_quantizable_pattern( + self, node: Node, quantization_config, filter_fn + ) -> None: # Propagate annotation to quantizable patterns. 
if ( (node.target in propagation_quantizable_ops) and (not _is_any_annotated([node])) and (node.op == "call_function") - and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] ): def is_all_inputs_connected_to_quantized_op(input_nodes): @@ -1010,11 +1361,23 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): return False return True + if _skip_annotate([node], filter_fn): + return + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + if node.target is torch.ops.aten.max_pool2d.default: # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not input_nodes_to_check = [node.all_input_nodes[0]] if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + if quantization_config is not None: + warnings.warn( + f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}." + ) return + self._annotate_maxpool2d(node, quantization_config) return elif node.target is torch.ops.aten.cat.default: @@ -1057,18 +1420,24 @@ def _annotate_output_share_observer_as_input( ) return - def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: + def _annotate_output_for_int8_in_int8_out_pattern_entry( + self, + model: torch.fx.GraphModule, + ): + for node in model.graph.nodes: + self._annotate_output_for_int8_in_int8_out_pattern(node) + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: r""" Check and insert observer at output of node in int8_in_int8_out_ops if needed. Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 """ edge_or_node: Tuple[Node, Node] - if ( - (node.target in int8_in_int8_out_ops) - and (_is_any_annotated([node])) - and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] - ): + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): if node.target == torch.ops.aten.max_pool2d.default: maxpool_node = node if not _is_all_annotated( @@ -1077,6 +1446,7 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: ] ): return + # Get the quantization_annotation from getitem_node maxpool_node_quantization_annotation = ( maxpool_node.meta[QUANT_ANNOTATION_KEY] @@ -1101,7 +1471,10 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: return def _annotate_linear( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: linear_partitions = get_source_partitions( gm.graph, [torch.nn.Linear, torch.nn.functional.linear] @@ -1120,12 +1493,15 @@ def _annotate_linear( ): raise ValueError(f"{linear_node} is not an aten linear operator") # skip annotation if it is already annotated - if _is_annotated([linear_node]): + if _skip_annotate([linear_node], filter_fn): continue self._annotate_linear_node_helper(linear_node, True, quantization_config) def _annotate_linear_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: postop_list = [ torch.nn.ReLU, @@ -1147,8 +1523,13 @@ def _annotate_linear_unary( torch.ops.aten.linear.default, ): continue - if 
_is_annotated([unary_node, linear_node]): + if _skip_annotate([unary_node, linear_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([linear_node, unary_node]) continue + self._annotate_linear_node_helper(linear_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, @@ -1158,7 +1539,8 @@ def _annotate_linear_unary( def _annotate_linear_binary_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # linear + binary_op + (optional) unary op binary_op_list = [operator.add] @@ -1215,8 +1597,13 @@ def _annotate_linear_binary_unary( if unary_node is None else [unary_node, binary_node, linear_node] ) - if _is_annotated(node_list): + if _skip_annotate(node_list, filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node_list) continue + self._annotate_linear_node_helper( linear_node, False, quantization_config ) diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index ae9ae60b8a3b..88ccc1454f44 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -23,6 +23,7 @@ ) from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, @@ -193,40 +194,6 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_symmetric_config_and_operators() -def _get_module_name_filter(module_name: str): - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name - - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 - - - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" - """ - - def module_name_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - - def _normalize_path(n): - prefix = 0 - # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. - if n.startswith("L['self']."): - prefix = len("L['self'].") - return n[prefix:] - - names = [_normalize_path(n) for n, _ in nn_module_stack.values()] - return module_name in names - - return module_name_filter - - def _get_module_type_filter(tp: Callable): """Get the module_type_filter function for a given module type, the filter accepts a node and checks if the node comes from a module that has certain module type
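Example usage of the `set_module_name_qconfig` API introduced by this patch, as an illustrative sketch distilled from the tests above. The capture/prepare/convert entry points (e.g. `capture_pre_autograd_graph` imported from `torch._export`) are assumed and may differ across PyTorch versions; the model and module name are placeholders.

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph  # assumed import path
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)  # left in fp32: no matching config
        self.sub = torch.nn.Linear(5, 5)  # quantized via its module name

    def forward(self, x):
        return self.sub(self.linear(x))


m = M().eval()
example_inputs = (torch.randn(3, 5),)

# Only the submodule named "sub" receives the default static config; module-name
# configs take priority over module-type and global configs during annotation.
quantizer = X86InductorQuantizer().set_module_name_qconfig(
    "sub", xiq.get_default_x86_inductor_quantization_config()
)

m = capture_pre_autograd_graph(m, example_inputs)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration
m = convert_pt2e(m)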