diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py
index 5c9528debbe..838dd44733e 100644
--- a/backends/arm/quantizer/arm_quantizer_utils.py
+++ b/backends/arm/quantizer/arm_quantizer_utils.py
@@ -11,11 +11,9 @@
 # Utility functions for TOSAQuantizer
 #
 
-from typing import cast, Sequence
+from typing import cast
 
-import torch
-from torch._subclasses import FakeTensor
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
@@ -45,62 +43,3 @@ def mark_node_as_annotated(node: Node) -> None:
     if Q_ANNOTATION_KEY not in node.meta:
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
     node.meta[Q_ANNOTATION_KEY]._annotated = True
-
-
-def is_ok_for_quantization(node: Node, gm: GraphModule):
-    """Check if an node can be quantized. The node can not be quantized if:
-    - The node does not output a float tensor or,
-    - The node outputs a large scalar.
-    """
-    return not (is_non_float_tensor(node) or is_large_scalar(node, gm))
-
-
-def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
-    targets = target_str.split(".")
-    for target in targets[:-1]:
-        module = module.get_submodule(target)
-    return getattr(module, targets[-1])
-
-
-def is_large_scalar(node: Node, gm: GraphModule):
-    """Check if input is a large scalar value. So that we can skip quantization for the node
-    since histc op (in HistogramObserver) only works for values up to certain upper bound
-    """
-    if node.op == "get_attr" and isinstance(node.target, str):
-        tensor = get_node_target(gm, node.target)
-        # torch.histc works until this upper bound
-        HISTC_UPPER_BOUND = 3.4028235e15
-        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
-    return False
-
-
-def is_non_float_tensor(node: Node) -> bool:
-    """Check if the output of a node has a data type other than `torch.float32`.
-
-    If the output is not `torch.float32`, quantization cannot be performed, as
-    observers only work with floating-point tensors.
-
-    Args:
-        node (Node): The node to check the output(s) for.
-
-    Returns:
-        bool: `True` if the data type is not float32, otherwise `False`.
-
-    Note:
-        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
-          element is **not** an instance of `FakeTensor` or does **not** have
-          `torch.float32` as its data type.
-        - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
-          function returns True.
- """ - if "val" in node.meta and isinstance(node.meta["val"], Sequence): - return any( - not isinstance(fake_tensor, FakeTensor) - or fake_tensor.dtype != torch.float32 - for fake_tensor in node.meta["val"] - ) - - if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): - return True - - return node.meta["val"].dtype != torch.float32 diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index e18f14c15f9..fbc1f8f7cb8 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -6,13 +6,14 @@ import logging import operator from dataclasses import dataclass -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Sequence import torch import torch.fx import torch.nn.functional as F from executorch.backends.arm.quantizer import QuantizationConfig from executorch.backends.arm.tosa_utils import get_node_debug_info +from torch._subclasses import FakeTensor from torch.fx import Node from torchao.quantization.pt2e.quantizer import ( @@ -24,7 +25,6 @@ from .arm_quantizer_utils import ( is_annotated, - is_ok_for_quantization, is_output_annotated, mark_node_as_annotated, ) @@ -78,9 +78,16 @@ def _is_ok_for_quantization( """ # Check output if quant_properties.quant_output is not None: - if not is_ok_for_quantization(node, gm): # type: ignore[attr-defined] + if _is_non_float_tensor(node): logger.debug( - f"Could not quantize node due to output: " + "Could not quantize non float tensor for the following output node: " + f"{get_node_debug_info(node, gm)}" + ) + + return False + elif _is_large_scalar(node, gm): + logger.debug( + "Could not quantize large scalar node for the following output node: " f"{get_node_debug_info(node, gm)}" ) @@ -99,10 +106,18 @@ def _is_ok_for_quantization( raise TypeError( f"n_arg must be a Node instance, got {type(n_arg).__name__!r}" ) - if not is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined] + + if _is_non_float_tensor(n_arg): logger.debug( - f'could not quantize node due to input "{node}": ' - f"{get_node_debug_info(node, gm)}" + "Could not quantize non float tensor for the following input " + f"node: {get_node_debug_info(node, gm)}" + ) + + return False + elif _is_large_scalar(n_arg, gm): + logger.debug( + "Could not quantize large scalar node for the following input " + f"node: {get_node_debug_info(node, gm)}" ) return False @@ -110,6 +125,58 @@ def _is_ok_for_quantization( return True +def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: str): + targets = target_str.split(".") + for target in targets[:-1]: + module = module.get_submodule(target) + return getattr(module, targets[-1]) + + +def _is_large_scalar(node: Node, gm: torch.fx.GraphModule): + """Check if input is a large scalar value. So that we can skip quantization for the + node since histc op (in HistogramObserver) only works for values up to certain upper + bound. + """ + if node.op == "get_attr" and isinstance(node.target, str): + tensor = _get_node_target(gm, node.target) + # torch.histc works until this upper bound + HISTC_UPPER_BOUND = 3.4028235e15 + return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND + return False + + +def _is_non_float_tensor(node: Node) -> bool: + """Check if the output of a node has a data type other than `torch.float32`. + + If the output is not `torch.float32`, quantization cannot be performed, as + observers only work with floating-point tensors. 
+
+    Args:
+        node (Node): The node to check the output(s) for.
+
+    Returns:
+        bool: `True` if the data type is not float32, otherwise `False`.
+
+    Note:
+        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
+          element is **not** an instance of `FakeTensor` or does **not** have
+          `torch.float32` as its data type.
+        - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
+          function returns True.
+    """
+    if "val" in node.meta and isinstance(node.meta["val"], Sequence):
+        return any(
+            not isinstance(fake_tensor, FakeTensor)
+            or fake_tensor.dtype != torch.float32
+            for fake_tensor in node.meta["val"]
+        )
+
+    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
+        return True
+
+    return node.meta["val"].dtype != torch.float32
+
+
 def _annotate_input(node: Node, quant_property: _QuantProperty):
     if is_annotated(node):
         raise RuntimeError(