176 changes: 113 additions & 63 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -54,7 +54,13 @@
)
from torch import fx
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
)
from torchao.quantization.pt2e.quantizer import (
ComposableQuantizer,
DerivedQuantizationSpec,
@@ -154,78 +160,120 @@ def get_supported_operators(cls) -> list[OperatorConfig]:


# Quantization Specification used by Neutron NPU
act_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=MinMaxObserver,
ch_axis=0,
)
def act_qspec(is_qat: bool):
eps = 2**-12
observer_or_fake_quant_ctr = (
FusedMovingAvgObsFakeQuantize.with_args(
observer=MovingAverageMinMaxObserver, eps=eps
)
if is_qat
else HistogramObserver.with_args(eps=eps)
)

return QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
)


def wgt_qspec(is_qat: bool):
observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
if is_qat
else MinMaxObserver
)

return QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
ch_axis=0,
)


def wgt_fc_qspec(is_qat: bool):
observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
if is_qat
else MinMaxObserver
)

return QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
)

wgt_fc_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=MinMaxObserver,
)

# Is set by the *PatternQuantizer directly.
bias_qspec = None


class NeutronQuantizer(ComposableQuantizer):
def __init__(self, neutron_target_spec: NeutronTargetSpec):
def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
self.neutron_target_spec = neutron_target_spec
static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
self.is_qat = is_qat

static_qconfig = QuantizationConfig(
act_qspec(is_qat=is_qat),
act_qspec(is_qat=is_qat),
wgt_qspec(is_qat=is_qat),
None,
)
static_fc_qconfig = QuantizationConfig(
act_qspec(is_qat=is_qat),
act_qspec(is_qat=is_qat),
wgt_fc_qspec(is_qat=is_qat),
None,
)

OpQuantizer = NeutronAtenQuantizer
super().__init__(
[
NeutronAtenQuantizer(AbsPattern(), static_qconfig),
NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
NeutronAtenQuantizer(CatPattern(), static_qconfig),
NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
NeutronAtenQuantizer(ConvTranspose2dPattern(), static_qconfig),
NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
NeutronAtenQuantizer(MmPattern(self), static_qconfig),
NeutronAtenQuantizer(MulTensorPattern(), static_qconfig),
NeutronAtenQuantizer(PadPattern(), static_qconfig),
NeutronAtenQuantizer(PermutePattern(), static_qconfig),
NeutronAtenQuantizer(ReluPattern(), static_qconfig),
NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
NeutronAtenQuantizer(SliceTensorPattern(), static_qconfig),
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
NeutronAtenQuantizer(TanhPattern(), static_qconfig),
NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(TransposeIntPattern(), static_qconfig),
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
OpQuantizer(MulTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SliceTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
OpQuantizer(TransposeIntPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
]
)

# Mapping ops defined in quantizer partition types to its quantizer
self.op_to_quantizer = {
pt: q for q in self.quantizers for pt in q.pattern.partition_types()
@@ -235,7 +283,9 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
pt: False for q in self.quantizers for pt in q.pattern.partition_types()
}
self.cluster_quantizers = [
NeutronAtenQuantizer(ActivationsConcatClusterPattern(self), static_qconfig)
NeutronAtenQuantizer(
ActivationsConcatClusterPattern(self, is_qat=is_qat), static_qconfig
)
]

def transform_for_annotation(
@@ -288,7 +338,7 @@ def _annotate_inputs(self, model: fx.GraphModule):
continue

if node.op == "placeholder" and len(node.users) > 0:
_annotate_output_qspec(node, act_qspec)
_annotate_output_qspec(node, act_qspec(self.is_qat))
self._mark_input_node_as_annotated(node)

def validate(self, model: torch.fx.GraphModule) -> None:
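Not part of this diff, but a minimal usage sketch of the new switch: the same NeutronQuantizer drives both PTQ and QAT, and is_qat only changes which observer/fake-quantize constructors end up in the QuantizationSpecs. The module path and the torchao prepare_pt2e / prepare_qat_pt2e / convert_pt2e entry points are assumptions, not taken from this PR.

import torch

# Assumed import paths; adjust to the actual executorch/torchao layout in use.
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)


def quantize_model(model, example_inputs, neutron_target_spec, is_qat=False):
    # is_qat=True selects the fake-quantize constructors (FusedMovingAvgObsFakeQuantize /
    # FakeQuantize); is_qat=False keeps the PTQ observers (HistogramObserver / MinMaxObserver).
    quantizer = NeutronQuantizer(neutron_target_spec, is_qat=is_qat)
    exported = torch.export.export(model, example_inputs).module()

    if is_qat:
        prepared = prepare_qat_pt2e(exported, quantizer)
        # ... fine-tune `prepared` here before conversion ...
    else:
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*example_inputs)  # calibration pass for the observers

    return convert_pt2e(prepared)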
62 changes: 46 additions & 16 deletions backends/nxp/quantizer/patterns.py
@@ -14,7 +14,11 @@
from torch import fx
from torch._ops import OpOverload
from torch.fx import Node
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e import (
FakeQuantize,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
)
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
@@ -59,7 +63,8 @@ class PartitionAnchors:
| tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
] = field(default_factory=list)
weights: list[
tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec],
tuple[fx.Node, NodeArgsIdx]
| tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize],
] = field(default_factory=list)
biases: list[
tuple[fx.Node, NodeArgsIdx]
@@ -69,12 +74,18 @@
literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
output: list[
tuple[fx.Node]
| tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec],
| tuple[
fx.Node,
FixedQParamsQuantizationSpec | SharedQuantizationSpec,
],
] = field(default_factory=list)
empty: bool = False


class QuantizationPattern(ABC):
def __init__(self, is_qat: bool = False):
self.is_qat = is_qat

@abstractmethod
def partition_types(self) -> list[OpOverload]:
"""
@@ -148,11 +159,12 @@ def get_anchors_for_fixed_quant_specs(
zero_point: int,
quant_min: int = -128,
quant_max: int = 127,
is_qat: bool = False,
) -> PartitionAnchors:
node = fused_partition[0].nodes[-1]
assert len(fused_partition[0].input_nodes) == 1

qspec = FixedQParamsQuantizationSpec(
qspec_or_fake_quantize = FixedQParamsQuantizationSpec(
dtype=torch.int8,
scale=scale,
zero_point=zero_point,
@@ -166,7 +178,7 @@
weights=[],
biases=[],
output=[
(node, qspec),
(node, qspec_or_fake_quantize),
],
)

@@ -190,7 +202,9 @@ def partition_types(self):


class AddmmPattern(QuantizationPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool):
super().__init__(is_qat=is_qat)

self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -365,7 +379,11 @@ def get_anchors(
ch_axis=0,
)

weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
weight_observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
if self.is_qat
else PerChannelMinMaxObserver
)
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -399,7 +417,9 @@ def partition_types(self) -> list[OpOverload]:


class Conv2dPattern(ConvPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool = False):
super().__init__(is_qat=is_qat)

self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -426,7 +446,11 @@
ch_axis=0,
)

weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
weight_observer_or_fake_quant_ctr = (
FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
if self.is_qat
else PerChannelMinMaxObserver
)
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -563,7 +587,9 @@ def replacement_op(self):


class LinearPattern(QuantizationPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool = False):
super().__init__(is_qat=is_qat)

self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -637,7 +663,9 @@ def partition_types(self):


class MmPattern(QuantizationPattern):
def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool = False):
super().__init__(is_qat=is_qat)

self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -802,7 +830,7 @@ def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 256.0, zero_point=-128
fused_partition, scale=1.0 / 256.0, zero_point=-128, is_qat=self.is_qat
)


@@ -820,7 +848,7 @@ def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 256.0, zero_point=-128
fused_partition, scale=1.0 / 256.0, zero_point=-128, is_qat=self.is_qat
)


@@ -838,7 +866,7 @@ def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 128.0, zero_point=0
fused_partition, scale=1.0 / 128.0, zero_point=0, is_qat=self.is_qat
)


@@ -856,7 +884,7 @@ def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 128.0, zero_point=0
fused_partition, scale=1.0 / 128.0, zero_point=0, is_qat=self.is_qat
)
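
A side note on the fixed qparams threaded through these get_anchors calls: scale=1/256 with zero_point=-128 spans a [0, 1] output range (sigmoid/softmax-style ops), and scale=1/128 with zero_point=0 spans [-1, 1] (tanh-style ops). A quick arithmetic check, assuming the usual int8 affine mapping q = round(x / scale) + zero_point clamped to [-128, 127]; this helper is illustrative only, not part of the PR:

def quantize_int8(x: float, scale: float, zero_point: int) -> int:
    # Standard affine quantization with clamping to the int8 range.
    q = round(x / scale) + zero_point
    return max(-128, min(127, q))

# Outputs in [0, 1]: scale=1/256, zero_point=-128 covers the full int8 range.
assert quantize_int8(0.0, 1 / 256, -128) == -128
assert quantize_int8(1.0, 1 / 256, -128) == 127  # 256 - 128 = 128, clamped to 127

# Outputs in [-1, 1]: scale=1/128, zero_point=0 is symmetric around zero.
assert quantize_int8(-1.0, 1 / 128, 0) == -128
assert quantize_int8(1.0, 1 / 128, 0) == 127  # 128 clamped to 127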


@@ -884,7 +912,9 @@ class ActivationsConcatClusterPattern(QuantizationPattern):
"""

def __init__(self, neutron_quantizer):
def __init__(self, neutron_quantizer, is_qat: bool = False):
super().__init__(is_qat=is_qat)

self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info