12 changes: 10 additions & 2 deletions backends/cadence/aot/compiler.py

@@ -24,6 +24,7 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
+    CadenceW8A32MixedQuantizer,
 )
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
@@ -59,6 +60,7 @@ def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
+    quantizer: Optional[CadenceQuantizer] = None,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
@@ -73,6 +75,12 @@ def trace(
         torch.ops.aten.rms_norm.default,
     ]
 
+    if isinstance(quantizer, CadenceW8A32MixedQuantizer):
+        ops_to_keep += [
+            torch.ops.aten.gru.input,
+            torch.ops.aten.gru.data,
+        ]
+
     program = trace_fn(
         model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )
@@ -99,7 +107,7 @@ def prepare_pt2(
     Returns a GraphModule with the prepared model.
     """
 
-    traced_program = trace(model, inputs, dump_graphs=dump_graphs)
+    traced_program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
     prepared_program = prepare_traced_pt2(
         traced_program, quantizer, dump_graphs=dump_graphs
     )
@@ -184,7 +192,7 @@ def get_fake_quant_model(
     # Make the model inference mode by calling model.eval()
    model.eval()
 
-    program = trace(model, inputs, dump_graphs=dump_graphs)
+    program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
 
     if dump_graphs:
         logging.info("Graph after trace:")
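The effect of the new `quantizer` argument is that `trace` keeps the GRU from being decomposed during export whenever the mixed W8A32 quantizer is in play, so the pattern matcher later sees a single `aten.gru` node. A minimal sketch of driving this path, assuming `prepare_pt2` accepts `(model, inputs, quantizer)` and that a plain `nn.GRU` module is a valid input; the model and shapes here are illustrative:

```python
import torch
from executorch.backends.cadence.aot.compiler import prepare_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)


class TinyGru(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Hidden size kept a multiple of 4 so the SIMD check in
        # MixedW8A32GruPattern does not bail out.
        self.gru = torch.nn.GRU(input_size=16, hidden_size=16)

    def forward(self, x: torch.Tensor, h: torch.Tensor):
        return self.gru(x, h)


model = TinyGru().eval()
inputs = (torch.randn(1, 1, 16), torch.randn(1, 1, 16))

# Because a CadenceW8A32MixedQuantizer is passed through to trace(),
# torch.ops.aten.gru.input / .data are added to ops_to_keep and the GRU
# survives export intact.
prepared = prepare_pt2(model, inputs, quantizer=CadenceW8A32MixedQuantizer())
```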
5 changes: 5 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml

@@ -558,3 +558,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_w8a32_conv_out
+
+- func: cadence::quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_gru_out
25 changes: 25 additions & 0 deletions backends/cadence/aot/ops_registrations.py

@@ -578,6 +578,15 @@
     "quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor"
+)
+
+lib.define(
+    "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
 aten_lib.define(
@@ -2646,3 +2655,19 @@ def quantized_w8a32_conv_meta(
         channel_last=False,
     )
     return src.new_empty(output_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_gru")
+def quantized_w8a32_gru_meta(
+    inputs: torch.Tensor,
+    hidden: torch.Tensor,
+    weights_inputs: torch.Tensor,
+    w_i_scale: float,
+    weights_hidden: torch.Tensor,
+    w_h_scale: float,
+    bias_inputs: torch.Tensor,
+    b_i_scale: float,
+    bias_hidden: torch.Tensor,
+    b_h_scale: float,
+) -> torch.Tensor:
+    return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
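The `@register_fake` implementation only has to get shapes and dtypes right: it allocates a `(2, hidden_size)` tensor, presumably the GRU output and next hidden state stacked. A sketch of exercising it under `FakeTensorMode`, assuming that importing `ops_registrations` is sufficient to register the `cadence::` library; all shapes and scales below are made up:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from executorch.backends.cadence.aot import ops_registrations  # noqa: F401

with FakeTensorMode():
    inputs = torch.randn(1, 16)
    hidden = torch.randn(1, 16)
    w_i = torch.zeros(48, 16, dtype=torch.int8)  # 3 gates x hidden_size rows
    w_h = torch.zeros(48, 16, dtype=torch.int8)
    b_i = torch.zeros(48, dtype=torch.int8)
    b_h = torch.zeros(48, dtype=torch.int8)
    out = torch.ops.cadence.quantized_w8a32_gru(
        inputs, hidden, w_i, 0.05, w_h, 0.05, b_i, 0.1, b_h, 0.1
    )
    # The fake kernel returns inputs.new_empty((2, hidden.shape[-1])).
    assert out.shape == (2, 16) and out.dtype == inputs.dtype
```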
46 changes: 46 additions & 0 deletions backends/cadence/aot/quantizer/fusion_pass.py

@@ -25,6 +25,7 @@
     LinearPattern,
     MatmulPattern,
     MixedW8A32ConvPattern,
+    MixedW8A32GruPattern,
     MixedW8A32LinearPattern,
     ReluPattern0,
     ReluPattern1,
@@ -528,6 +529,41 @@ def get_args_and_kwargs_mixed_w8a32_conv(
     return args, kwargs
 
 
+def get_args_and_kwargs_mixed_w8a32_gru(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Stride, padding, dilation, groups not supported yet
+
+    assert len(dequants_weights) == 2
+    assert len(dequants_biases) == 2
+    w_i_scale = dequants_weights[0].args[1]
+    w_h_scale = dequants_weights[1].args[1]
+    b_i_scale = dequants_biases[0].args[1]
+    b_h_scale = dequants_biases[1].args[1]
+
+    args = (
+        other_inputs[0],
+        other_inputs[1],
+        weights_inputs[0],
+        w_i_scale,
+        weights_inputs[1],
+        w_h_scale,
+        bias_inputs[0],
+        b_i_scale,
+        bias_inputs[1],
+        b_h_scale,
+    )
+    kwargs = {}
+
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
@@ -707,6 +743,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     dequants_biases,
                     op_node,
                 )
+            elif isinstance(pattern, MixedW8A32GruPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_gru(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                    op_node,
+                )
 
             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),
57 changes: 57 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py

@@ -661,3 +661,60 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_conv.default
+
+
+class MixedW8A32GruPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.gru.input]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        gru_layer = fused_partition[0].nodes[-1]
+        if len(gru_layer.kwargs) > 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                gru_layer,
+            )
+
+        # Bail if input or states are not multiple of 4 (SIMD)
+        if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                gru_layer,
+            )
+        if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                gru_layer,
+            )
+
+        class Wrapper:  # noqa: B903
+            def __init__(self, args, meta):
+                self.args = args
+                self.meta = meta
+
+        wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta)
+
+        return (
+            PartitionAnchors(
+                inputs=[],
+                # pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
+                weights=[(wrapper, 0), (wrapper, 1)],
+                # pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
+                biases=[(wrapper, 2), (wrapper, 3)],
+                output=[],
+                others=[(gru_layer, 0), (gru_layer, 1)],
+            ),
+            gru_layer,
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_gru.default
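The `Wrapper` class compensates for how `aten.gru.input` packs its parameters: for a single-layer GRU with biases, `args[2]` is a flat tensor list `[w_ih, w_hh, b_ih, b_hh]`, while `PartitionAnchors` wants `(node, arg_index)` pairs. The wrapper masquerades as a node whose `.args` is that flat list, so the weights land at indices 0/1 and the biases at 2/3. A standalone sketch of that indexing, with stand-in tensors:

```python
import torch

# Illustrative stand-ins for the four GRU parameter tensors
# (3 gates * hidden_size = 24 rows for hidden_size 8).
w_ih, w_hh = torch.randn(24, 8), torch.randn(24, 8)
b_ih, b_hh = torch.randn(24), torch.randn(24)


class Wrapper:  # same shape as the class in get_anchors above
    def __init__(self, args, meta):
        self.args = args
        self.meta = meta


wrapper = Wrapper((w_ih, w_hh, b_ih, b_hh), meta={})
assert wrapper.args[0] is w_ih  # weights=[(wrapper, 0), (wrapper, 1)]
assert wrapper.args[2] is b_ih  # biases=[(wrapper, 2), (wrapper, 3)]
```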
4 changes: 4 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py

@@ -25,6 +25,7 @@
     LinearPattern,
     MatmulPattern,
     MixedW8A32ConvPattern,
+    MixedW8A32GruPattern,
     MixedW8A32LinearPattern,
     QuantizationPattern,
     ReluPattern0,
@@ -325,6 +326,9 @@ def __init__(self) -> None:
         quantizers.append(
             CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
         )
+        quantizers.append(
+            CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym)
+        )
         super().__init__(quantizers)
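With this registration, the mixed W8A32 quantizer treats GRU the same way as its other mixed patterns: activations stay fp32 (A32) while weights are quantized to symmetric int8 (W8), per `qconfig_A32W8sym`. A sketch of inspecting the composed quantizer, assuming the `ComposableQuantizer` base exposes its sub-quantizers as `.quantizers` and each `CadenceAtenQuantizer` exposes its `.pattern` (both attribute names are assumptions):

```python
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)

q = CadenceW8A32MixedQuantizer()
for sub in q.quantizers:  # attribute name assumed from ComposableQuantizer
    print(type(sub.pattern).__name__)
# Expected to now include MixedW8A32GruPattern alongside the existing
# mixed linear/conv patterns.
```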