98 changes: 94 additions & 4 deletions backends/arm/quantizer/quantization_config.py
@@ -3,6 +3,13 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Provide quantization configuration helpers for the Arm backend.

Define a small dataclass to carry activation/weight/bias specs and helper
accessors that validate specs before use. Use this module to build and validate
quantization specs consumed by the annotator.

"""

# pyre-unsafe

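For orientation, a minimal sketch of how a config like this might be assembled, assuming the PyTorch 2 export `QuantizationSpec` API from `torch.ao.quantization`; the import path for `QuantizationConfig` and the observer choices are illustrative, not taken from this diff:

```python
import torch
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Import path assumed from the file location shown above.
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig

# Per-tensor symmetric int8 activations (illustrative choices).
act_spec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=HistogramObserver,
)

# Per-channel symmetric int8 weights, one scale per output channel.
weight_spec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)

config = QuantizationConfig(
    input_activation=act_spec,
    output_activation=act_spec,
    weight=weight_spec,
    bias=None,  # get_bias_qspec derives bias qparams for conv/linear nodes
)
```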
@@ -19,13 +26,38 @@

@dataclass(eq=True, frozen=True)
class QuantizationConfig:
"""Provide a container for quantization specs.

Hold optional specs for input/output activations, weights, and bias, and
expose validated accessors.

Attributes:
input_activation (QuantizationSpec | None): Spec for input activations.
output_activation (QuantizationSpec | None): Spec for output activations.
weight (QuantizationSpec | None): Spec for weights.
bias (QuantizationSpec | None): Spec for bias values.

"""

input_activation: QuantizationSpec | None
output_activation: QuantizationSpec | None
weight: QuantizationSpec | None
bias: QuantizationSpec | None

def get_input_act_qspec(self) -> QuantizationSpec | None:
"""Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid."""
"""Get the validated input activation spec.

Validate that the input activation qscheme is supported before
returning the spec.

Returns:
QuantizationSpec | None: Input activation spec, or ``None`` when
unset.

Raises:
ValueError: If the qscheme is not per-tensor affine or symmetric.

"""
if self.input_activation is None:
return None
# Validate that input_activation uses a supported qscheme
@@ -39,7 +71,19 @@ def get_input_act_qspec(self) -> QuantizationSpec | None:
return self.input_activation

def get_output_act_qspec(self) -> QuantizationSpec | None:
"""Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid."""
"""Get the validated output activation spec.

Validate that the output activation qscheme is supported before
returning the spec.

Returns:
QuantizationSpec | None: Output activation spec, or ``None`` when
unset.

Raises:
ValueError: If the qscheme is not per-tensor affine or symmetric.

"""
if self.output_activation is None:
return None
# Validate that output_activation uses a supported qscheme
@@ -53,7 +97,18 @@ def get_output_act_qspec(self) -> QuantizationSpec | None:
return self.output_activation

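The validation bodies referenced by the comments above are collapsed in this view. A minimal sketch of the kind of check both activation accessors perform, assuming only the supported schemes named in the docstrings; the helper name is hypothetical:

```python
import torch

SUPPORTED_ACT_QSCHEMES = (torch.per_tensor_affine, torch.per_tensor_symmetric)


def _check_activation_qscheme(spec) -> None:
    # Hypothetical helper mirroring the collapsed checks: reject any
    # activation spec whose qscheme is not per-tensor affine/symmetric.
    if spec.qscheme not in SUPPORTED_ACT_QSCHEMES:
        raise ValueError(f"Unsupported activation qscheme: {spec.qscheme}")
```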
def get_weight_qspec(self) -> QuantizationSpec | None:
"""Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid."""
"""Get the validated weight spec.

Validate that the weight qscheme is supported (per-tensor or
per-channel symmetric) before returning the spec.

Returns:
QuantizationSpec | None: Weight spec, or ``None`` when unset.

Raises:
ValueError: If the qscheme is not a supported symmetric scheme.

"""
if self.weight is None:
return None
# Validate that weight uses a supported qscheme
@@ -65,11 +120,46 @@ def get_weight_qspec(self) -> QuantizationSpec | None:
return self.weight

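Per-channel symmetric weight quantization keeps one scale per output channel (`ch_axis=0` for conv/linear weights). A quick self-contained illustration with stock PyTorch, independent of this module:

```python
import torch

w = torch.randn(8, 4, 3, 3)  # conv weight: [C_out, C_in, kH, kW]
scales = w.abs().amax(dim=(1, 2, 3)) / 127.0     # one scale per output channel
zero_points = torch.zeros(8, dtype=torch.int64)  # symmetric => zero-points are 0
qw = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)
print(qw.q_per_channel_scales().shape)  # torch.Size([8])
```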
def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
"""Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float."""
"""Get the derived or validated bias spec.

For conv/linear ops, derive bias qparams from the input/weight observers.
Otherwise, validate a user-provided floating-point bias spec.

Args:
node (torch.fx.Node): Node whose bias spec is requested.

Returns:
QuantizationSpec | None: Derived or provided bias spec, or ``None``
when unset.

Raises:
ValueError: If deriving qparams sees an unexpected number of
observers/fake-quantizers, or if a provided bias dtype is not
floating-point.

"""

def _derive_qparams_fn(
obs_or_fqs: list[ObserverOrFakeQuantize],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute bias scale/zero-point from activation/weight observers.

Expect two observers or fake-quantize modules: one for the input
activation and one for the weight. The bias scale is the product of
input and weight scales, and the zero-point is a tensor of zeros.

Args:
obs_or_fqs (list[ObserverOrFakeQuantize]): Observers/fake-quant
in order ``[act, weight]``.

Returns:
                tuple[torch.Tensor, torch.Tensor]: Bias scale tensor and
integer zero-point tensor.

Raises:
ValueError: If the list does not contain exactly two items.

"""
# Validate expected number of observers/fake-quantizes
if len(obs_or_fqs) != 2:
raise ValueError(
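The diff is truncated inside this `raise`. To make the derivation concrete: the bias scale is `input_scale * weight_scale` with all-zero zero-points, and for conv/linear the result is typically wired up through the PyTorch 2 export `DerivedQuantizationSpec`. A hedged sketch; the function name, node variables, and int32 range below are assumptions, not taken from this diff:

```python
import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec


def derive_bias_qparams(obs_or_fqs) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine [act, weight] observer qparams into bias qparams."""
    if len(obs_or_fqs) != 2:
        raise ValueError(f"Expected [act, weight] observers, got {len(obs_or_fqs)}")
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale  # scale product; broadcasts per-channel
    zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)  # symmetric
    return bias_scale, zero_point


# Wiring for a conv/linear node; input_node/weight_node/node are hypothetical:
# bias_qspec = DerivedQuantizationSpec(
#     derived_from=[(input_node, node), (weight_node, node)],
#     derive_qparams_fn=derive_bias_qparams,
#     dtype=torch.int32,                      # assumed bias dtype
#     quant_min=torch.iinfo(torch.int32).min,
#     quant_max=torch.iinfo(torch.int32).max,
#     qscheme=torch.per_tensor_symmetric,
# )
```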