diff --git a/backends/qualcomm/README.md b/backends/qualcomm/README.md
index 3c0fdd8f987..be7cd427d6e 100644
--- a/backends/qualcomm/README.md
+++ b/backends/qualcomm/README.md
@@ -73,3 +73,67 @@ examples/qualcomm
 Please see this [README.md](../../examples/qualcomm/README.md).
 
 Further, an example build script is provided as [build.sh](scripts/build.sh).
+
+## Issues
+If you want a problem you encountered to be addressed, it helps to include reproduction information that points maintainers in the right direction. Please also follow the [policy](../../CONTRIBUTING.md#issues) when filing issues.
+
+## Pull Requests
+PRs are always welcome to help improve the codebase comprehensively. Before submitting changes, please follow these steps:
+
+- **Check the Coding Style**:
+  Make sure your code follows the [style guides](../../CONTRIBUTING.md#coding-style) and passes the [lint checks](../../CONTRIBUTING.md#lintrunner).
+
+- **Add Unit Tests**:
+  The following is an example of adding a test case after [creating a new operator builder](builders/README.md). Please navigate to the `backends/qualcomm/tests` folder and put a minimal example module in `model.py`, e.g.:
+  ```python
+  class IndexPut(torch.nn.Module):
+      ...
+
+
+  # please insert the implementation in alphabetical order
+  class LayerNorm(torch.nn.Module):
+      def __init__(self):
+          super().__init__()
+          self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
+
+      def forward(self, x):
+          return self.layer_norm(x)
+
+
+  class LeakyReLUDefault(torch.nn.Module):
+      ...
+  ```
+  Also extend the `TestQNNFloatingPointOperator` and `TestQNNQuantizedOperator` sections in `test_qnn_delegate.py`, e.g.:
+  ```python
+  class TestQNNQuantizedOperator(TestQNN):
+      def test_qnn_backend_interpolate_nearest_2d(self):
+          ...
+
+      # please insert the implementation in alphabetical order
+      def test_qnn_backend_layer_norm(self):
+          module = LayerNorm()  # noqa: F405
+          sample_input = (torch.randn(196, 768),)
+          module = self.get_qdq_module(module, sample_input)
+          self.lower_module_and_test_output(module, sample_input)
+
+      def test_qnn_backend_leaky_relu(self):
+          ...
+  ```
+
+- **Verify Unit Test Results**:
+  ```bash
+  cd $PATH_TO_EXECUTORCH
+  # example usage of running a unit test
+  python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_layer_norm -s $DEVICE_SERIAL -m SM8650 -b build-android/ -a $PATH_TO_TEST_ARTIFACTS
+  ```
+  The test graph is expected to contain a single delegated node, with only placeholder / output nodes left. Check the execution report for more information.
+
+- **Code Reviews**:
+  Please ping authors of Qualcomm AI Engine Direct related PRs for review; possible candidates are listed below:
+  - [chiwwang](https://github.com/chiwwang)
+  - [shewu-quic](https://github.com/shewu-quic)
+  - [chunit-quic](https://github.com/chunit-quic)
+  - [winskuo-quic](https://github.com/winskuo-quic)
+  - [chuntl](https://github.com/chuntl)
+  - [haowhsu-quic](https://github.com/haowhsu-quic)
+
+Thanks again for your contribution!
diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md
new file mode 100644
index 00000000000..a81df0d6def
--- /dev/null
+++ b/backends/qualcomm/builders/README.md
@@ -0,0 +1,361 @@
+# Contribution for More Operators
+Thank you for contributing to the Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of implementing an operator builder, unblock yourself, and land pull requests more efficiently.
+
+## Sections
+* [References](#references)
+* [Getting Started](#getting-started)
+  * [Identify Unsupported Operator](#identify-unsupported-operator)
+  * [Check Operator Spec](#check-operator-spec)
+  * [Implementation](#implementation)
+  * [Quantizer Annotation](#quantizer-annotation)
+* [Issues](#issues)
+* [Pull Requests](#pull-requests)
+
+## References
+### Qualcomm AI Engine Direct
+- [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html)
+- [Supported Operators in Backends](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/operations.html#backend-supplements)
+
+### PyTorch
+- [torch.nn Operator Definitions](https://pytorch.org/docs/stable/nn.html)
+- [torch.nn.functional Operator Definitions](https://pytorch.org/docs/stable/nn.functional.html)
+- [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native)
+
+## Getting Started
+### Identify Unsupported Operator
+Consider that we're enabling the following model:
+```python
+class MyModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
+        self.linear = torch.nn.Linear(768, 100)
+
+    def forward(self, x):
+        return self.linear(self.layer_norm(x))
+```
+When we try to lower it with the Qualcomm backend:
+```python
+from executorch.examples.qualcomm.utils import build_executorch_binary
+
+build_executorch_binary(
+    model=MyModel(),
+    inputs=(torch.randn(200, 768),),
+    soc_model="SM8650",
+    file_name="my_model",
+    dataset=None,
+)
+```
+Assuming there is no `torch.nn.LayerNorm` support, you should see the following error log:
+```bash
+File "/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 77, in is_node_supported
+    op_wrapper = self.node_visitors[node.target.__name__].define_node(
+KeyError: 'aten.native_layer_norm.default'
+```
+This log gets straight to the point: there is no suitable conversion for delegating this torch operator to Qualcomm AI Engine Direct. Here `node_visitors` is a dictionary that maps an operator target name to its implementation callback. This tutorial aims to help you register the missing one.
+The very first step is to locate which operator type we are going to support. Sometimes the target name of an operator might be obscure; the following snippet can help you trace it back through its call stack:
+```python
+from executorch.backends.qualcomm.utils.utils import capture_program
+
+prog = capture_program(MyModel(), (torch.randn(200, 768),))
+for node in prog.exported_program.graph.nodes:
+    if node.op == "call_function" and node.target.__name__ == 'aten.native_layer_norm.default':
+        print(node.meta["source_fn_stack"])
+```
+It will provide a hint to the source PyTorch layer that the missing operator maps to:
+```bash
+[('l__self___layer_norm', <class 'torch.nn.modules.normalization.LayerNorm'>)]
+```
+
+### Check Operator Spec
+- **Qualcomm AI Engine Direct**:
+  You could collect information about `LayerNorm`'s IO via the documents mentioned in the [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct):
+  * inputs
+    - in[0] - input activation / required
+    - in[1] - gamma / optional
+    - in[2] - beta / optional
+  * parameters
+    - "epsilon" / optional
+    - "axes" / required
+  * outputs
+    - out[0] - output activation / required
+
+  The required tensors must be provided since no default values are given inside the QNN runtime. The order of IOs (`input activation`, `gamma`, `beta`) matters, whereas parameters (`epsilon`, `axes`) are recognized by their literal names:
+  ```c
+  typedef struct {
+      /// A human-readable name for the operation instance.
+      const char* name;
+      /// The name of the operation package to which this operation's type belongs.
+      const char* packageName;
+      /// The name of operation type (e.g. Conv2D).
+      const char* typeName;
+      /// The number of static parameters provided in the params array.
+      uint32_t numOfParams;
+      /// Array of operation parameters.
+      Qnn_Param_t* params;
+      /// The number of input tensors.
+      uint32_t numOfInputs;
+      /// Array of input tensors.
+      Qnn_Tensor_t* inputTensors;
+      /// The number of output tensors.
+      uint32_t numOfOutputs;
+      /// Array of output tensors.
+      Qnn_Tensor_t* outputTensors;
+  } Qnn_OpConfigV1_t;
+  ```
+  This is the data structure used to check operator validity in the QNN SDK. During validation, tensors are retrieved sequentially and passed through a series of spec examinations, while parameters are matched by their names:
+  ```c
+  typedef struct {
+      /// Parameter type: scalar or tensor
+      Qnn_ParamType_t paramType;
+      /// Name of the parameter
+      const char* name;
+
+      union UNNAMED {
+          /// Scalar parameter specification
+          Qnn_Scalar_t scalarParam;
+          /// Tensor parameter specification; tensors referred to must be STATIC.
+          Qnn_Tensor_t tensorParam;
+      };
+  } Qnn_Param_t;
+  ```
+  The `name` value equals the parameter name described in [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html); for the `LayerNorm` case these are `epsilon` and `axes`.
+
+  If you find it hard to correlate the missing operator with the documentation, this [table](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/SupportedOps.html) might be helpful for searching. In some cases, an exact match may not exist; consider seeking a mathematically equivalent approach or notifying maintainers for further analysis.
+
+- **PyTorch**:
+  We could also read the IO spec from the [function declaration](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/layer_norm.cpp) mentioned in the [PyTorch Documentation](#pytorch):
+  * inputs
+    - in[0] - input activation / required
+    - in[1] - normalized_shape / required
+    - in[2] - weight_opt / optional
+    - in[3] - bias_opt / optional
+    - in[4] - eps / required
+
+  Through comparing with the [equation](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html), we can sort out how the arguments in the Qualcomm manual (`gamma` / `beta` / `epsilon`) relate to PyTorch's (`weight_opt` / `bias_opt` / `eps`). The unmatched parameter `axes` will be discussed further in the [implementation](#implementation) part.
+
+### Implementation
+Let's start by adding a new definition for the `LayerNorm` operator in `qnn_constants.py`:
+```python
+@dataclass(init=False, frozen=True)
+class OpHardSwish:
+    ...
+
+
+# please insert it in alphabetical order
+@dataclass(init=False, frozen=True)
+class OpLayerNorm:
+    op_name: str = "LayerNorm"
+    param_epsilon = "epsilon"
+    param_axes = "axes"
+
+
+@dataclass(init=False, frozen=True)
+class OpLogSoftmax:
+    ...
+```
+The conventions are:
+- op_name: string describing the operator
+- param_xxx: string for a consumed parameter
+
+The content must exactly match the literal values mentioned in the [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct) or in `QnnOpDef.h` under `$QNN_SDK_ROOT/include/QNN/`:
+```c
+#define QNN_OP_LAYER_NORM               "LayerNorm"
+#define QNN_OP_LAYER_NORM_PARAM_EPSILON "epsilon"
+#define QNN_OP_LAYER_NORM_PARAM_AXES    "axes"
+```
+
+Next, create a new file named in snake_case (e.g. `op_layer_norm.py`) and import the required modules (please check the comments to get an idea of how each is used):
+```python
+# typing for the function signature below
+from typing import Dict
+
+# pybind interface for invoking QNN APIs
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+# tensors or other numerics will be shipped in numpy format
+import numpy as np
+import torch
+
+# common keywords of the Qualcomm backend
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+
+# the op builder will inherit NodeVisitor and have its own implementation;
+# register_node_visitor book-keeps the dictionary of target name vs. callback
+from .node_visitor import NodeVisitor, register_node_visitor
+
+# the definitions required to build the operator in QNN
+from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+# utility to get the parameter value when creating a tensor in QNN
+from .utils import get_parameter
+```
+Start with the class and method declarations:
+```python
+@register_node_visitor
+class LayerNormVisitor(NodeVisitor):
+    target = ["aten.native_layer_norm.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+```
+It's mandatory to have the `target` member in list form, since multiple targets may map to the same implementation; e.g. `aten.leaky_relu.default` and `aten.prelu.default` have similar equations and only differ in the negative slope.
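+For instance, a visitor that serves two such targets could be sketched as below (illustrative only; the class name is hypothetical, not the exact builder shipped in the repo, and it reuses the imports from the snippet above):
+```python
+@register_node_visitor
+class ExampleLeakyReluVisitor(NodeVisitor):
+    # both targets share one implementation; node.target.__name__ tells them apart
+    target = ["aten.leaky_relu.default", "aten.prelu.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        # the negative-slope coefficient would be read from node.args here
+        is_prelu = node.target.__name__ == "aten.prelu.default"
+        ...
+```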
+The `nodes_to_wrappers` argument is a dictionary that maintains the relationship between graph nodes and their output tensors. It acts as a memo so that tensor objects are not re-created for nodes that have already been traversed.
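+Conceptually, the bookkeeping follows the memoization pattern sketched below (illustrative only; the actual implementation also tracks a `wrapper_idx` per node for multi-output cases):
+```python
+def get_or_create_wrapper(node, nodes_to_wrappers, create_wrapper):
+    # reuse the wrapper if this node has been visited before
+    if node in nodes_to_wrappers:
+        return nodes_to_wrappers[node]
+    # otherwise build a QNN tensor wrapper for the node's output and remember it
+    wrapper = create_wrapper(node)
+    nodes_to_wrappers[node] = wrapper
+    return wrapper
+```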
+
+Now we can start to fill in the function body step by step:
+1. Define the input activation tensor:
+   ```python
+   input_node = node.args[0]
+   input_tensor = self.get_tensor(input_node, node)
+   input_tensor_wrapper = self.define_tensor(
+       input_node,
+       input_tensor,
+       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+       nodes_to_wrappers,
+       is_input_tensor=True,
+   )
+   ```
+   Using the information from the [Check Operator Spec](#check-operator-spec) section, we can easily extract the desired nodes.
+   The `get_tensor` method is responsible for retrieving the torch tensor in the correct axis order if the `layout_transform` pass has been applied.
+   The `define_tensor` method generates the tensor object for the QNN API, and the result is memoized in the aforementioned `nodes_to_wrappers`.
+   A few of its arguments are worth describing in more detail:
+   - **node**: the current graph node
+   - **tensor**: the torch tensor emitted by the node
+   - **tensor_type**: a type compatible with the QNN SDK; often `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters
+   - **nodes_to_wrappers**: the dictionary of graph nodes and their output tensors (note: the tensor here is not a torch tensor but a wrapped object for QNN)
+   - **is_input_tensor**: flag telling whether the current tensor is an input activation or a parameter, which is important for fixed point mixed-precision to work properly
+   - **node_name**: (optional) tensor name for the user to specify
+   - **wrapper_idx**: (optional) defaults to zero if the node is not a tuple; otherwise it acts as an index into the output tensors. e.g. when slicing an input tensor into multiple outputs, `wrapper_idx` is necessary for getting the correct wrapped tensor object
+
+2. Define the input gamma / beta tensors:
+   ```python
+   weight_node = node.args[2]
+   weight_tensor = get_parameter(weight_node, self.edge_program)
+   weight_tensor_wrapper = self.define_tensor(
+       weight_node,
+       weight_tensor,
+       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+       nodes_to_wrappers,
+       is_input_tensor=False,
+   )
+
+   bias_node = node.args[3]
+   bias_tensor = get_parameter(bias_node, self.edge_program)
+   bias_tensor_wrapper = self.define_tensor(
+       bias_node,
+       bias_tensor,
+       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+       nodes_to_wrappers,
+       is_input_tensor=False,
+   )
+   ```
+   The logic is similar and straightforward. Please set the `tensor_type` and `is_input_tensor` arguments carefully according to each tensor's properties.
+
+3. Define the parameters:
+   ```python
+   normalized_shapes = node.args[1]
+   if len(normalized_shapes) != 1:
+       print("QNN only supports normalized output with rank 1")
+       return
+
+   axes = [len(input_tensor.shape) - 1]
+   axes_shape = [len(axes)]
+   epsilon = node.args[4]
+   ```
+   Here you can see a constraint introduced by Qualcomm AI Engine Direct. Unlike PyTorch's LayerNorm operator, QNN can only normalize over a single dimension (a rank-1 normalized shape). Therefore we log a reminder to the user and return early; this is treated as a validation failure in the partitioner, and the operator falls back to CPU.
+   When passing tensor-type parameters via the pybind interface, it's also required to ship extra information such as the tensor shape in list form, e.g. `axes_shape = [len(axes)]`. More details will be provided in the coming steps.
+
+4. Define the output tensor:
+   ```python
+   output_tensor = self.get_tensor(node, node, 0)
+   output_tensor_wrapper = self.define_tensor(
+       node,
+       output_tensor,
+       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+       nodes_to_wrappers,
+       is_input_tensor=False,
+   )
+   ```
+   Although the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with the corresponding types `QNN_TENSOR_TYPE_APP_WRITE` / `QNN_TENSOR_TYPE_APP_READ`, users are still expected to use `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic to the `define_tensor` method.
+
+5. Generate the operator object in the QNN graph:
+   ```python
+   layer_norm_op = PyQnnWrapper.PyQnnOpWrapper(
+       node.name,
+       QNN_OP_PACKAGE_NAME_QTI_AISW,
+       OpLayerNorm.op_name,
+   )
+   ```
+
+6. Pass the IO tensors to the operator object:
+   ```python
+   layer_norm_op.AddInputTensors(
+       [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
+   )
+   layer_norm_op.AddOutputTensors([output_tensor_wrapper])
+   ```
+   The IO tensor objects created before are gathered up and shipped to the operator object.
+
+7. Pass the parameters to the operator object:
+   ```python
+   layer_norm_op.AddScalarParam(
+       OpLayerNorm.param_epsilon,
+       PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
+       {QCOM_DATA: np.float32(epsilon)},
+   )
+   layer_norm_op.AddTensorParam(
+       OpLayerNorm.param_axes,
+       PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+       len(axes_shape),
+       axes_shape,
+       np.array(axes, dtype=np.uint32),
+       True,
+   )
+   ```
+   By checking the `Shape` property of the parameter in the [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct), it should be clear which API to use, e.g.:
+   - "epsilon" > __Shape__: scalar
+   - "axes" > __Shape__: 1D of shape[M]
+
+   The arguments of `AddScalarParam` are:
+   - **name**: string matching the parameter name in the Qualcomm AI Engine Direct manual
+   - **data_type**: type compatible with the QNN SDK, e.g. `QNN_DATATYPE_FLOAT_32`, `QNN_DATATYPE_UINT_32`, etc.
+   - **attr**: dictionary for shipping data; currently only the `QCOM_DATA` key is used
+
+   The arguments of `AddTensorParam` are:
+   - **name**: string matching the parameter name in the Qualcomm AI Engine Direct manual
+   - **data_type**: type compatible with the QNN SDK, e.g. `QNN_DATATYPE_FLOAT_32`, `QNN_DATATYPE_UINT_32`, etc.
+   - **rank**: number of dimensions of the tensor
+   - **dims**: shape of the tensor
+   - **data**: tensor data
+   - **copy_data**: should be set to True for constant parameters
+
+8. Finally, return the operator object for the partitioner to conduct validation:
+   ```python
+   return layer_norm_op
+   ```
+   Also update `__init__.py` for `register_node_visitor` to work properly:
+   ```python
+   from . import (
+       ...
+       op_index_put,
+       # please insert modules in alphabetical order
+       op_layer_norm,
+       op_linear,
+       ...
+   )
+
+   __all__ = [
+       ...
+       op_index_put,
+       # please insert modules in alphabetical order
+       op_layer_norm,
+       op_linear,
+       ...
+   ]
+   ```
+
+### Quantizer Annotation
+The operator should now be functional for Qualcomm backends. For the operator to work in fixed precision, we should also make `QnnQuantizer` correctly insert observers for recording calibrated encodings. Please read more in the [Quantization Annotation Tutorial](../quantizer/README.md).
+
+## Issues
+Please refer to the [issue section](../README.md#issues) for more information.
+
+## Pull Requests
+Please refer to the [PR section](../README.md#pull-requests) for more information.
diff --git a/backends/qualcomm/quantizer/README.md b/backends/qualcomm/quantizer/README.md
new file mode 100644
index 00000000000..6870ecc76ac
--- /dev/null
+++ b/backends/qualcomm/quantizer/README.md
@@ -0,0 +1,189 @@
+# Contribution for Operator Annotation
+Thank you for contributing to the Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of annotating an operator in `QnnQuantizer`, unblock yourself, and land pull requests more efficiently.
+
+## Sections
+* [References](#references)
+* [Getting Started](#getting-started)
+* [Issues](#issues)
+* [Pull Requests](#pull-requests)
+
+## References
+### Qualcomm AI Engine Direct
+- [Operator Definitions for HTP](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html)
+
+### PyTorch
+- [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native)
+
+## Getting Started
+Before extending an operator for quantization annotation, please make sure the operator builder has been well implemented (learn more in this [tutorial](../builders/README.md)).
+### Behavior of Annotation
+In order to conduct PTQ on a floating point precision graph, observers are required to be inserted after each graph node. The observed numeric ranges then go through different algorithms, which produce the `scale` and `offset` statistics used to represent the data in fixed point.
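+For context, a minimal end-to-end PTQ sketch using `QnnQuantizer` is shown below (assuming the quantizer's default 8-bit config; the quantizer's module path and the capture API may differ across ExecuTorch / PyTorch versions):
+```python
+import torch
+from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
+from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(768, 100)
+
+    def forward(self, x):
+        return torch.nn.functional.relu(self.linear(x))
+
+
+module = SimpleModel().eval()
+sample_input = (torch.randn(200, 768),)
+
+quantizer = QnnQuantizer()
+graph_module = capture_pre_autograd_graph(module, sample_input)
+prepared = prepare_pt2e(graph_module, quantizer)  # observers are inserted here
+prepared(*sample_input)                           # calibration records numeric ranges
+quantized = convert_pt2e(prepared)                # QDQ pairs landed with the observed encodings
+```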

+**The stages can be illustrated as follows**:
+- Floating point `nn.Module` after `torch.export.export`
+  ```mermaid
+  flowchart TB
+      input & kernel & bias --> id1(convolution) --> output
+  ```
+
+- Inserting observers for inspecting numeric ranges
+  ```mermaid
+  flowchart TB
+      input --> id2(input_act_obs) --> id1(convolution) --> id3(output_act_obs) --> output
+      kernel --> id4(weight_obs) --> id1(convolution)
+      bias --> id5(bias_obs) --> id1(convolution)
+  ```
+
+- Cascading QDQ pairs after landing encodings
+  ```mermaid
+  flowchart TB
+      input --> id2(Q_i) --> id3(DQ_i) --> id1(convolution) --> id4(Q_o) --> id5(DQ_o) --> output
+      kernel --> id6(Q_k) --> id7(DQ_k) --> id1(convolution)
+      bias --> id8(Q_b) --> id9(DQ_b) --> id1(convolution)
+  ```
+The Qualcomm backend will consume the generated encodings and lower operators with fixed precision. This tutorial will guide you through the details of inserting observers and some useful utilities.
+
+### Register Annotation via Operator Type
+Let's start with hooking a callback for the designated operator target:
+```python
+def register_annotator(ops: List[OpOverload]):
+    def decorator(annotator: Callable):
+        for op in ops:
+            OP_ANNOTATOR[op] = annotator
+
+    return decorator
+```
+The `register_annotator` decorator provides a convenient way to attach your own annotation logic; it requires a list of operator types as its input argument.
+For example, the torch activation functions have `copy` and `in-place` implementations whose only naming difference is an extra `_` postfix, and both map to the same [Core ATen](https://pytorch.org/docs/stable/torch.compiler_ir.html) operator after `to_edge`:
+```python
+@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
+```
+Here `torch.ops.aten.relu.default` / `torch.ops.aten.relu_.default` map to the `copy` / `in-place` versions, and both will ultimately be converted into `torch.ops.aten.relu.default`.
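+A quick way to see this mapping is to export a module that uses the in-place variant and inspect the call targets; the hedged check below uses plain `torch.export` (in the real flow, the conversion is finalized by `to_edge`):
+```python
+import torch
+
+
+class InplaceRelu(torch.nn.Module):
+    def forward(self, x):
+        return torch.relu_(x + 1)
+
+
+# export functionalizes the graph, so the in-place call shows up as the functional target
+ep = torch.export.export(InplaceRelu(), (torch.randn(4),))
+print([n.target for n in ep.graph.nodes if n.op == "call_function"])
+# expect torch.ops.aten.relu.default (alongside the add) rather than a relu_ variant
+```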

+
+The function signature is defined as follows, with two arguments:
+```python
+def annotate_xxx(node: Node, quantization_config: QuantizationConfig) -> None:
+```
+- __node__: the graph node to be observed
+- __quantization_config__: a data structure describing the quantization configurations for the IO activation / weight / bias
+
+### Example of Conv2d Annotation
+Conv2d accepts up to three input tensors: `input activation`, `kernel`, and `bias`. There are constraints imposed by the [Qualcomm AI Engine Direct Manual](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html#conv2d).
+Take 8-bit fixed point as an example:
+- __weight__: must be symmetrically quantized if a per-channel observer is applied
+- __bias__: must have `QNN_DATATYPE_SFIXED_POINT_32` and be symmetrically quantized with the expected encoding `scales = weight.scales * input.scale`, `offset = 0` if a per-channel observer is applied.
+
+Let's look at the simplified per-channel quantization configuration used in `QnnQuantizer`:
+```python
+def ptq_per_channel_quant_config(
+    act_dtype=torch.uint8, weight_dtype=torch.int8
+) -> QuantizationConfig:
+    ...
+    act_quantization_spec = QuantizationSpec(
+        dtype=act_dtype,
+        quant_min=torch.iinfo(act_dtype).min,
+        quant_max=torch.iinfo(act_dtype).max,
+        qscheme=torch.per_tensor_affine,
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
+    )
+
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(weight_dtype).min + 1,
+        quant_max=torch.iinfo(weight_dtype).max,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
+    )
+
+    bias_quantization_spec = _derived_bias_quant_spec
+
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+```
+Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of the IO activations, and apply rules to `weight` with `PerChannelMinMaxObserver` and `bias` with `_derived_bias_quant_spec` (a callable that calculates the encoding in the desired way) to meet the aforementioned constraints. The well-defined `quantization_config` will then be shipped to the callback for annotation.
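+To make the bias constraint concrete, here is a small numeric sketch (illustrative values only) of how the derived per-channel bias encoding follows `scale = input.scale * weight.scales` with `offset = 0`:
+```python
+import torch
+
+# illustrative values only: one input scale and per-output-channel weight scales
+input_scale = 0.02
+weight_scales = torch.tensor([0.0010, 0.0040, 0.0025])
+
+# derived bias encoding: scale = input_scale * weight_scale, offset = 0
+bias_scales = input_scale * weight_scales
+bias_fp = torch.tensor([0.37, -0.12, 0.05])
+int32_info = torch.iinfo(torch.int32)
+bias_q = torch.clamp(
+    torch.round(bias_fp / bias_scales), int32_info.min, int32_info.max
+).to(torch.int32)  # stored as QNN_DATATYPE_SFIXED_POINT_32
+```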
+
+Now, we can start to fill in the function body:
+- Register the annotator
+  ```python
+  @register_annotator(
+      [
+          torch.ops.aten.conv2d.default,
+          torch.ops.aten.conv1d.default,
+          torch.ops.aten.conv_transpose2d.input,
+      ]
+  )
+  def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+  ```
+  Multiple targets are expected to meet our annotation criteria; grouping them like this is encouraged for code reuse.
+
+- Define the map of input quantization specs
+  ```python
+  if _is_annotated([node]):
+      return
+
+  input_qspec_map = {}
+
+  # annotate input activation
+  input_act = node.args[0]
+  input_spec = quantization_config.input_activation
+  input_qspec_map[input_act] = input_spec
+
+  # annotate kernel
+  kernel = node.args[1]
+  input_qspec_map[kernel] = quantization_config.weight
+
+  # annotate bias
+  if len(node.args) > 2:
+      bias = node.args[2]
+      input_qspec_map[bias] = quantization_config.bias(node)
+  ```
+  We first check whether the current graph node has already been annotated. If not, an `input_qspec_map` dictionary required by the PyTorch framework is declared to provide the mapping between input graph nodes and their configurations.
+  The order of the parameters can be found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Convolution.cpp), as mentioned in [ATen Operator Definitions](#pytorch). Since the bias node is optional, the implementation invokes `_derived_bias_quant_spec` to calculate the per-channel bias encoding only if it exists.
+
+- Update the node's meta with a framework-compatible data structure
+  ```python
+  node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+      input_qspec_map=input_qspec_map,
+      output_qspec=quantization_config.output_activation,
+      _annotated=True,
+  )
+  ```
+  After `input_qspec_map` has been processed, it must be stored in the node's meta with a special tag (`QUANT_ANNOTATION_KEY`) so that `prepare_pt2e` / `convert_pt2e` can properly insert observers and QDQ pairs.
+
+### Common Annotators
+For operators without extra parameters to be observed, there are pre-defined annotation methods for convenience:
+- Single-in, single-out operators, e.g.:
+  ```python
+  @register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
+  def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
+      annotate_single_in_single_out(node, quantization_config)
+  ```
+
+- Binary-in, single-out operators, e.g.:
+  ```python
+  @register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
+  def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
+      annotate_binary(node, quantization_config)
+  ```
+
+- Shared encodings between input / output, e.g.:
+  ```python
+  # For operators that perform no arithmetic, IOs are expected to share the same encodings.
+  @register_annotator([torch.ops.aten.transpose.int])
+  def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
+      annotate_in_out_obs_sharing_op(node, quantization_config)
+      if not _is_annotated([node]):
+          annotate_single_in_single_out(node, quantization_config)
+  ```
+  The sharing annotator only works in the single-in-single-out scenario where the node's input has already been annotated. If that is not the case, we still need to invoke `annotate_single_in_single_out` (this path should be less likely).
+
+## Issues
+Please refer to the [issue section](../README.md#issues) for more information.
+
+## Pull Requests
+Please refer to the [PR section](../README.md#pull-requests) for more information.