80 changes: 80 additions & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -145,6 +145,86 @@ def get_symmetric_quantization_config(
    return quantization_config


@functools.lru_cache
def get_symmetric_a16w8_quantization_config(
    is_per_channel: bool = True,
    is_qat: bool = False,
    is_dynamic: bool = False,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    """
    16A8W quantization config: 16-bit activations, 8-bit weights.

    This configuration provides better accuracy than 8A8W while maintaining
    reasonable memory usage through 8-bit weights.

    Args:
        is_per_channel: Whether to use per-channel quantization for weights
        is_qat: Whether this is for Quantization Aware Training
        is_dynamic: Whether to use dynamic quantization
        weight_qmin: Minimum quantization value for weights
        weight_qmax: Maximum quantization value for weights

    Returns:
        QuantizationConfig with 16-bit activations and 8-bit weights
    """
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    # Setup observer/fake-quant for 16-bit activations
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            # HistogramObserver works well for 16-bit range
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    # 16-bit activation quantization spec
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min,  # -32768
        quant_max=torch.iinfo(torch.int16).max,  # 32767
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )

    # Reuse the 8-bit weight spec from the base symmetric config instead of
    # reconstructing it; only the activation specs change for 16-bit activations.
    base_config = get_symmetric_quantization_config(
        is_per_channel=is_per_channel,
        is_qat=is_qat,
        is_dynamic=is_dynamic,
        weight_qmin=weight_qmin,
        weight_qmax=weight_qmax,
    )
    # Replace the activation quantization specs with the 16-bit version
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,  # 16-bit input activations
            None,
            base_config.weight,  # 8-bit weights from base config
            None,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,  # 16-bit input activations
            act_quantization_spec,  # 16-bit output activations
            base_config.weight,  # 8-bit weights from base config
            None,
        )
    return quantization_config


NodeFilterType = Callable[[Node], bool]
"""Type for a Node Filter used by annotators. A Node filter is a function that takes
a Node and returns whether the node should be annotated or not.
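
A minimal usage sketch of the new helper (not part of the diff): it assumes the import path matches the file location shown above and that QuantizationConfig exposes input_activation / output_activation / weight fields matching the positional constructor call in the function; only .weight appears verbatim in the diff. Wiring the config into a concrete quantizer (e.g. via set_global(), as with the existing 8-bit config) is outside this change.

import torch

# Import path assumed from the file location in this diff
# (backends/arm/quantizer/arm_quantizer.py in the executorch tree).
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
)

cfg = get_symmetric_a16w8_quantization_config(is_per_channel=True)

# Activations use a symmetric int16 spec; the weight spec is the 8-bit one
# reused from get_symmetric_quantization_config.
assert cfg.input_activation.dtype == torch.int16   # field name assumed
assert cfg.output_activation.dtype == torch.int16  # field name assumed
assert cfg.weight is not None  # 8-bit weight spec from the base config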