4 changes: 2 additions & 2 deletions .ci/scripts/setup-samsung-linux-deps.sh
@@ -13,7 +13,7 @@ download_ai_lite_core() {
API_BASE="https://soc-developer.semiconductor.samsung.com/api/v1/resource/ai-litecore/download"
API_KEY=$SAMSUNG_AI_LITECORE_KEY

VERSION="0.5"
VERSION="0.7"
OS_NAME="Ubuntu 22.04"
OUT_FILE="/tmp/exynos-ai-litecore-v${VERSION}.tar.gz"
TARGET_PATH="/tmp/exynos_ai_lite_core"
@@ -62,7 +62,7 @@ install_enn_backend() {
export PYTHONPATH=${PYTHONPATH:-}:${EXECUTORCH_ROOT}/..
}

-AI_LITE_CORE_VERSION=0.5.0
+AI_LITE_CORE_VERSION=0.7.0

download_ai_lite_core ${AI_LITE_CORE_VERSION}
install_enn_backend
6 changes: 6 additions & 0 deletions .github/workflows/pull.yml
@@ -935,6 +935,12 @@ jobs:
python -m executorch.examples.samsung.aot_compiler --model_name=$model -c E9955
done
# Test quant models
model_scripts="deeplab_v3 edsr inception_v3 inception_v4 mobilenet_v2 mobilenet_v3 resnet18 resnet50 vit wav2letter"
for m_script in $model_scripts; do
python -m executorch.examples.samsung.scripts.${m_script} -c e9955 -p A8W8
done
# Test ops
python -m unittest discover -s backends/samsung/test/ops -p "test_*.py"
201 changes: 201 additions & 0 deletions backends/samsung/_passes/annotate_qparams.py
@@ -0,0 +1,201 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Any, Dict, List, Optional

import torch
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import get_buffer
from torch.export import ExportedProgram
from torch.fx import GraphModule, Node


class AnnotateQparamsPass(ExportPass):
"""This parse is to add quantize properties to node need to be quantized.

Annotate Quant params:
For src_node->Q->DQ->..., we will add the quant params from Q->DQ node
to the src_node

Annotate Requantize:
For src_node->Q->DQ->Q->DQ->..., if the multiple Q->DQ contains
different quant params, we will mark the src_node as need requantize,
and add Q->DQ after removing all the Q->DQs.
"""

propagate_nodes = {
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.squeeze_copy.default,
exir_ops.edge.aten.squeeze_copy.dim,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.concat.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.expand_copy.default,
}

def __init__(self, edge_program: ExportedProgram):
super().__init__()
self.edge_program = edge_program

def _get_last_dqs(self, node: Node) -> List[Node]:
r"""From one Q-DQ node, find the last DQs in the quantization node chain.


need to consider such case:
/--Q-DQ-node1
node->Q->DQ--node-node2
\--Q-DQ-node3
This is a dfs implemention, so result will keep sorted
Args:
node (Node): Search DQ from this node.

Returns:
List[Node]: list of DQ node by original sequence
"""

def _impl(node: Node, res_list: List[Node]):
if (
node.target not in QuantConstants.QUANT_OPS_KEY_MAP
and node.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
):
return
for user in node.users.keys():
if (
user.target not in QuantConstants.QUANT_OPS_KEY_MAP
and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
):
res_list.append(node)
else:
_impl(user, res_list)

res_list: List[Node] = []
for user in node.users:
_impl(user, res_list)
return res_list

def _propagate_quant_params(self, node: Node):
assert (
quantize_attrs := node.meta.get("quantize_attrs")
), "Must be annotated node."
        requantize_map: Dict[int, Dict] = node.meta.get("requantize", {})
while node.users:
if len(node.users) != 1:
break
user = list(node.users.keys())[0]
if (
user.target not in QuantConstants.QUANT_OPS_KEY_MAP
and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
):
break
node = user
        # Case 1: ...-q-dq(cur)-propagate_node-node(not q-dq)
        # Case 2: propagate_node(propagated)-propagate_node-node(not q-dq)
for idx, user in enumerate(node.users.keys()):
            # For branches that need to be requantized, propagate the requantize params
user_attrs = requantize_map.get(idx, quantize_attrs)
if user.target not in self.propagate_nodes:
continue
if len(user.users) == 1:
                # Possibly no need to check len(user.users) > 1
user_of_user = list(user.users)[0]
                # node-q-dq-propagate-q-dq does not need propagation
if (
user_of_user.target in QuantConstants.QUANT_OPS_KEY_MAP
or user_of_user.target in QuantConstants.DEQUANT_OPS_KEY_MAP
):
continue
            # Propagate quant params for node-q-dq-propagate_node-node(not q-dq)
user.meta["quantize_attrs"] = user_attrs
self._propagate_quant_params(user)

def _annotate_requantize(self, node: Node):
assert (
ori_quant_attrs := node.meta.get("quantize_attrs")
), "No quant parameters found"
list_for_requantize = self._get_last_dqs(node)
node.meta["requantize"] = node.meta.get("requantize", {})

        # We use the index to mark the outputs to be requantized,
        # because the user object and its name may change when we requantize them.

def _check_same(requant_obj, ori_obj) -> bool:
if type(requant_obj) != type(ori_obj): # noqa E721
                # We require exactly the same type here.
return False
if not isinstance(requant_obj, torch.Tensor):
return requant_obj == ori_obj
if requant_obj.shape != ori_obj.shape:
return False
return bool((requant_obj == ori_obj).all())

requantize_map: Dict[int, Dict] = node.meta["requantize"]
for idx, dq in enumerate(list_for_requantize):
q = dq.all_input_nodes[0]
if q.target not in QuantConstants.QUANT_OPS_KEY_MAP:
continue
key_map = QuantConstants.DEQUANT_OPS_KEY_MAP[dq.target]
requantize_attrs = self.get_quant_attrs(q, key_map)
if not all(
_check_same(ori_quant_attrs[key], requantize_attrs[key])
for key in key_map.values()
):
requantize_map[idx] = requantize_attrs

def _annotate(self, graph_module: GraphModule):
for node in graph_module.graph.nodes:
key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
if not key_map:
continue
source_node = node.args[0]
if source_node.target in (
*QuantConstants.QUANT_OPS_KEY_MAP,
*QuantConstants.DEQUANT_OPS_KEY_MAP,
):
                # Currently, don't add quant info for Q/DQ nodes here.
continue
elif source_node.target == operator.getitem:
source_node = source_node.args[0]
quant_attrs = self.get_quant_attrs(node, key_map)
source_node.meta["quantize_attrs"] = quant_attrs
self._annotate_requantize(source_node)
self._propagate_quant_params(source_node)

def call(self, graph_module: GraphModule):
self._annotate(graph_module)
graph_module.recompile()
return PassResult(graph_module, True)

def get_quant_attrs(
self, quant_node: torch.fx.Node, key_map: Optional[Dict] = None
) -> Dict[str, Any]:
quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments]
quant_attrs = dict.fromkeys(quant_attr_keys)
for key, attr in zip(quant_attr_keys[1:], quant_node.args[1:]):
            # For channel-wise quantization, params are stored as buffer nodes.
if isinstance(attr, torch.fx.Node):
attr = get_buffer(self.edge_program, attr)
quant_attrs[key] = attr
quant_attrs["target"] = quant_node.target
if key_map is None:
return quant_attrs
miss_attrs = []
for aten_attr, snc_attr in key_map.items():
if aten_attr not in quant_attrs:
miss_attrs.append(aten_attr)
continue
attr = quant_attrs[aten_attr]
quant_attrs.pop(aten_attr)
quant_attrs[snc_attr] = attr
        assert (
            not miss_attrs
        ), f"Missing quant attrs {miss_attrs} for node {quant_node.name}"
return quant_attrs
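To make the annotation idea concrete, here is a minimal runnable sketch (not part of this PR) using plain torch.fx, with stand-in `quantize`/`dequantize` functions in place of the edge-dialect Q/DQ ops; `quantize`, `dequantize`, and `Toy` are invented names for illustration. It copies a Q node's params onto the node that produced its input, which is the essence of what `_annotate` does on a real `ExportedProgram`.

```python
import torch
import torch.fx

# Stand-ins for the edge-dialect Q/DQ ops. torch.fx.wrap keeps each call as a
# single call_function node instead of tracing through the function body.
def quantize(x, scale, zero_point):
    return torch.clamp(torch.round(x / scale) + zero_point, 0, 255)

def dequantize(x, scale, zero_point):
    return (x - zero_point) * scale

torch.fx.wrap("quantize")
torch.fx.wrap("dequantize")

class Toy(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)                # src_node
        q = quantize(y, 0.02, 128)       # Q
        return dequantize(q, 0.02, 128)  # DQ

gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.target is quantize:
        src = node.args[0]
        # Copy the Q node's params onto the producer -- the same shape of
        # information AnnotateQparamsPass stores under "quantize_attrs".
        src.meta["quantize_attrs"] = {"scale": node.args[1], "zero_point": node.args[2]}

for node in gm.graph.nodes:
    if "quantize_attrs" in node.meta:
        print(node.name, node.meta["quantize_attrs"])  # relu {'scale': 0.02, 'zero_point': 128}
```

The real pass additionally walks past chains of Q/DQ users (`_get_last_dqs`), records per-output requantize params by index, and pushes attrs through shape-only ops such as `view_copy` and `permute_copy`.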
65 changes: 65 additions & 0 deletions backends/samsung/_passes/annotate_scalar_parameters.py
@@ -0,0 +1,65 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.samsung.quantizer.quantizer import global_quant_info
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export import ExportedProgram


class AnnotateScalarParametersPass(ExportPass):
"""
Need to add quantization parameters for scalars for some ops
Ifm(Quantized)------TargetOP---
Scalar(Non-Quant)---/
Notice: Such scalars are converted to tensor node by default pass
"""

TARGET_OPS = {
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.div.Tensor,
}

def __init__(self, edge_program: ExportedProgram):
super().__init__()
self.edge_program = edge_program

def annotate(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta:
continue
torch_quant_dtype = global_quant_info.weight_precison.torch_dtype
for input_arg in node.all_input_nodes:
if input_arg.op not in ("placeholder", "get_attr") or not is_param_node(
self.edge_program, input_arg
):
continue
else:
tensor = get_param_tensor(self.edge_program, input_arg)
if not tensor.shape:
qparams = {
QuantConstants.QUANT_KEY.scale: float(tensor),
QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype,
QuantConstants.QUANT_KEY.quant_max: torch.iinfo(
torch_quant_dtype
).max,
QuantConstants.QUANT_KEY.quant_min: torch.iinfo(
torch_quant_dtype
).min,
QuantConstants.QUANT_KEY.zero_point: 0,
}
input_arg.meta["quantize_attrs"] = qparams

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
self.annotate(graph_module)
graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
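For intuition, the qparams this pass builds for a 0-dim scalar make the scalar quantize exactly to 1: the scalar's own value becomes the scale and the zero point is 0, so dequantization reproduces the value. A hypothetical standalone helper sketching this (`scalar_qparams` is an invented name, and `torch.int8` is an assumed default; the real pass takes the dtype from `global_quant_info.weight_precison.torch_dtype`):

```python
import torch

def scalar_qparams(tensor: torch.Tensor, quant_dtype: torch.dtype = torch.int8) -> dict:
    """Hypothetical helper mirroring the qparams built by
    AnnotateScalarParametersPass for a 0-dim scalar tensor."""
    assert not tensor.shape, "expected a 0-dim scalar tensor"
    info = torch.iinfo(quant_dtype)
    return {
        "scale": float(tensor),  # the scalar's own value becomes the scale
        "zero_point": 0,
        "quant_min": info.min,
        "quant_max": info.max,
        "quant_dtype": quant_dtype,
    }

qp = scalar_qparams(torch.tensor(0.5))
# round(0.5 / qp["scale"]) + 0 == 1, and 1 * qp["scale"] == 0.5 exactly.
print(qp)
```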