Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 2 additions & 63 deletions backends/arm/quantizer/arm_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
# Utility functions for TOSAQuantizer
#

from typing import cast, Sequence
from typing import cast

import torch
from torch._subclasses import FakeTensor
from torch.fx import GraphModule, Node
from torch.fx import Node

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
Expand Down Expand Up @@ -45,62 +43,3 @@ def mark_node_as_annotated(node: Node) -> None:
if Q_ANNOTATION_KEY not in node.meta:
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
node.meta[Q_ANNOTATION_KEY]._annotated = True


def is_ok_for_quantization(node: Node, gm: GraphModule):
    """Return True when *node* can be quantized.

    A node is rejected when it does not output a float tensor, or when it
    outputs a single large scalar (see ``is_large_scalar``).
    """
    # Guard on the output dtype first, mirroring the cheaper check.
    if is_non_float_tensor(node):
        return False
    return not is_large_scalar(node, gm)


def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
targets = target_str.split(".")
for target in targets[:-1]:
module = module.get_submodule(target)
return getattr(module, targets[-1])


def is_large_scalar(node: Node, gm: GraphModule):
    """Report whether *node* is a ``get_attr`` yielding one oversized scalar.

    HistogramObserver relies on ``torch.histc``, which only supports values up
    to a fixed magnitude, so such nodes have to be skipped when quantizing.
    """
    # Only attribute fetches with a string target can resolve to a tensor.
    if node.op != "get_attr" or not isinstance(node.target, str):
        return False

    tensor = get_node_target(gm, node.target)
    # Upper bound beyond which torch.histc stops working.
    HISTC_UPPER_BOUND = 3.4028235e15
    return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND


def is_non_float_tensor(node: Node) -> bool:
    """Check whether the node's recorded output is anything but float32.

    Quantization observers only work on floating-point tensors, so a node is
    reported as non-float when its recorded value is missing, is not a
    `FakeTensor`, or carries a dtype other than `torch.float32`.

    Args:
        node (Node): The node to check the output(s) for.

    Returns:
        bool: `True` if the data type is not float32, otherwise `False`.

    Note:
        - For a multi-output node (`node.meta["val"]` is a sequence), one
          offending element is enough to return `True`.
        - A missing `node.meta["val"]` entry, or one that is not a
          `FakeTensor`, also returns `True`.
    """
    value = node.meta.get("val")

    if isinstance(value, Sequence):
        # Multi-output node: every element must be a float32 FakeTensor.
        return not all(
            isinstance(item, FakeTensor) and item.dtype == torch.float32
            for item in value
        )

    if not isinstance(value, FakeTensor):
        # Covers both an absent "val" entry and values of unexpected type.
        return True

    return value.dtype != torch.float32
81 changes: 74 additions & 7 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import logging
import operator
from dataclasses import dataclass
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Sequence

import torch
import torch.fx
import torch.nn.functional as F
from executorch.backends.arm.quantizer import QuantizationConfig
from executorch.backends.arm.tosa_utils import get_node_debug_info
from torch._subclasses import FakeTensor

from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
Expand All @@ -24,7 +25,6 @@

from .arm_quantizer_utils import (
is_annotated,
is_ok_for_quantization,
is_output_annotated,
mark_node_as_annotated,
)
Expand Down Expand Up @@ -78,9 +78,16 @@ def _is_ok_for_quantization(
"""
# Check output
if quant_properties.quant_output is not None:
if not is_ok_for_quantization(node, gm): # type: ignore[attr-defined]
if _is_non_float_tensor(node):
logger.debug(
f"Could not quantize node due to output: "
"Could not quantize non float tensor for the following output node: "
f"{get_node_debug_info(node, gm)}"
)

return False
elif _is_large_scalar(node, gm):
logger.debug(
"Could not quantize large scalar node for the following output node: "
f"{get_node_debug_info(node, gm)}"
)

Expand All @@ -99,17 +106,77 @@ def _is_ok_for_quantization(
raise TypeError(
f"n_arg must be a Node instance, got {type(n_arg).__name__!r}"
)
if not is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined]

if _is_non_float_tensor(n_arg):
logger.debug(
f'could not quantize node due to input "{node}": '
f"{get_node_debug_info(node, gm)}"
"Could not quantize non float tensor for the following input "
f"node: {get_node_debug_info(node, gm)}"
)

return False
elif _is_large_scalar(n_arg, gm):
logger.debug(
"Could not quantize large scalar node for the following input "
f"node: {get_node_debug_info(node, gm)}"
)

return False

return True


def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: str):
targets = target_str.split(".")
for target in targets[:-1]:
module = module.get_submodule(target)
return getattr(module, targets[-1])


def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Tell whether *node* is a ``get_attr`` producing one huge scalar value.

    Such nodes must be skipped during quantization because the histc op used
    by HistogramObserver only handles values up to a certain upper bound.
    """
    # Anything other than a string-targeted attribute fetch is fine.
    if node.op != "get_attr" or not isinstance(node.target, str):
        return False

    tensor = _get_node_target(gm, node.target)
    # torch.histc works until this upper bound
    HISTC_UPPER_BOUND = 3.4028235e15
    return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND


def _is_non_float_tensor(node: Node) -> bool:
"""Check if the output of a node has a data type other than `torch.float32`.

If the output is not `torch.float32`, quantization cannot be performed, as
observers only work with floating-point tensors.

Args:
node (Node): The node to check the output(s) for.

Returns:
bool: `True` if the data type is not float32, otherwise `False`.

Note:
- If `node.meta["val"]` is a `list`, the function returns `True` if **any**
element is **not** an instance of `FakeTensor` or does **not** have
`torch.float32` as its data type.
- If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
function returns True.
"""
if "val" in node.meta and isinstance(node.meta["val"], Sequence):
return any(
not isinstance(fake_tensor, FakeTensor)
or fake_tensor.dtype != torch.float32
for fake_tensor in node.meta["val"]
)

if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
return True

return node.meta["val"].dtype != torch.float32


def _annotate_input(node: Node, quant_property: _QuantProperty):
if is_annotated(node):
raise RuntimeError(
Expand Down
Loading