diff --git a/.ci/scripts/setup-samsung-linux-deps.sh b/.ci/scripts/setup-samsung-linux-deps.sh index 434587975ab..c1f2912713b 100644 --- a/.ci/scripts/setup-samsung-linux-deps.sh +++ b/.ci/scripts/setup-samsung-linux-deps.sh @@ -13,7 +13,7 @@ download_ai_lite_core() { API_BASE="https://soc-developer.semiconductor.samsung.com/api/v1/resource/ai-litecore/download" API_KEY=$SAMSUNG_AI_LITECORE_KEY - VERSION="0.5" + VERSION="0.7" OS_NAME="Ubuntu 22.04" OUT_FILE="/tmp/exynos-ai-litecore-v${VERSION}.tar.gz" TARGET_PATH="/tmp/exynos_ai_lite_core" @@ -62,7 +62,7 @@ install_enn_backend() { export PYTHONPATH=${PYTHONPATH:-}:${EXECUTORCH_ROOT}/.. } -AI_LITE_CORE_VERSION=0.5.0 +AI_LITE_CORE_VERSION=0.7.0 download_ai_lite_core ${AI_LITE_CORE_VERSION} install_enn_backend diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 11e005847e6..5b646cba9d1 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -935,6 +935,12 @@ jobs: python -m executorch.examples.samsung.aot_compiler --model_name=$model -c E9955 done + # Test quant models + model_scripts="deeplab_v3 edsr inception_v3 inception_v4 mobilenet_v2 mobilenet_v3 resnet18 resnet50 vit wav2letter" + for m_script in $model_scripts; do + python -m executorch.examples.samsung.scripts.${m_script} -c e9955 -p A8W8 + done + # Test ops python -m unittest discover -s backends/samsung/test/ops -p "test_*.py" diff --git a/backends/samsung/_passes/annotate_qparams.py b/backends/samsung/_passes/annotate_qparams.py new file mode 100644 index 00000000000..663d1fdf5fa --- /dev/null +++ b/backends/samsung/_passes/annotate_qparams.py @@ -0,0 +1,201 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator +from typing import Any, Dict, List, Optional + +import torch +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._export.utils import get_buffer +from torch.export import ExportedProgram +from torch.fx import GraphModule, Node + + +class AnnotateQparamsPass(ExportPass): + """This parse is to add quantize properties to node need to be quantized. + + Annotate Quant params: + For src_node->Q->DQ->..., we will add the quant params from Q->DQ node + to the src_node + + Annotate Requantize: + For src_node->Q->DQ->Q->DQ->..., if the multiple Q->DQ contains + different quant params, we will mark the src_node as need requantize, + and add Q->DQ after removing all the Q->DQs. + """ + + propagate_nodes = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.default, + exir_ops.edge.aten.squeeze_copy.dim, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.concat.default, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.expand_copy.default, + } + + def __init__(self, edge_program: ExportedProgram): + super().__init__() + self.edge_program = edge_program + + def _get_last_dqs(self, node: Node) -> List[Node]: + r"""From one Q-DQ node, find the last DQs in the quantization node chain. 
+
+
+        Need to consider such a case:
+                     /--Q-DQ-node1
+        node->Q->DQ--node-node2
+                     \--Q-DQ-node3
+        This is a DFS implementation, so the result keeps the original order.
+        Args:
+            node (Node): Search DQ from this node.
+
+        Returns:
+            List[Node]: list of DQ nodes in original order
+        """
+
+        def _impl(node: Node, res_list: List[Node]):
+            if (
+                node.target not in QuantConstants.QUANT_OPS_KEY_MAP
+                and node.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
+            ):
+                return
+            for user in node.users.keys():
+                if (
+                    user.target not in QuantConstants.QUANT_OPS_KEY_MAP
+                    and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
+                ):
+                    res_list.append(node)
+                else:
+                    _impl(user, res_list)
+
+        res_list: List[Node] = []
+        for user in node.users:
+            _impl(user, res_list)
+        return res_list
+
+    def _propagate_quant_params(self, node: Node):
+        assert (
+            quantize_attrs := node.meta.get("quantize_attrs")
+        ), "Must be annotated node."
+        requantize_map: Dict[Node, Node] = node.meta.get("requantize", {})
+        while node.users:
+            if len(node.users) != 1:
+                break
+            user = list(node.users.keys())[0]
+            if (
+                user.target not in QuantConstants.QUANT_OPS_KEY_MAP
+                and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
+            ):
+                break
+            node = user
+        # Case 1: ...-q-dq(cur)-propagate_node-node(not q-dq)
+        # Case 2: propagate_node(propagated)-propagate_node-node(not q-dq)
+        for idx, user in enumerate(node.users.keys()):
+            # For branches that need to be requantized, propagate the requantize params
+            user_attrs = requantize_map.get(idx, quantize_attrs)
+            if user.target not in self.propagate_nodes:
+                continue
+            if len(user.users) == 1:
+                # Possibly no need to check len(users) > 1
+                user_of_user = list(user.users)[0]
+                # node-q-dq-propagate-q-dq: no need to propagate
+                if (
+                    user_of_user.target in QuantConstants.QUANT_OPS_KEY_MAP
+                    or user_of_user.target in QuantConstants.DEQUANT_OPS_KEY_MAP
+                ):
+                    continue
+            # Propagate quant params for node-q-dq-propagate_node-node (not q-dq)
+            user.meta["quantize_attrs"] = user_attrs
+            self._propagate_quant_params(user)
+
+    def _annotate_requantize(self, node: Node):
+        assert (
+            ori_quant_attrs := node.meta.get("quantize_attrs")
+        ), "No quant parameters found"
+        list_for_requantize = self._get_last_dqs(node)
+        node.meta["requantize"] = node.meta.get("requantize", {})
+
+        # We use the index to mark the output to be requantized,
+        # because the user object and name may change when we requantize them.
+
+        def _check_same(requant_obj, ori_obj) -> bool:
+            if type(requant_obj) != type(ori_obj):  # noqa E721
+                # We require exactly the same type here.
+ return False + if not isinstance(requant_obj, torch.Tensor): + return requant_obj == ori_obj + if requant_obj.shape != ori_obj.shape: + return False + return bool((requant_obj == ori_obj).all()) + + requantize_map: Dict[int, Dict] = node.meta["requantize"] + for idx, dq in enumerate(list_for_requantize): + q = dq.all_input_nodes[0] + if q.target not in QuantConstants.QUANT_OPS_KEY_MAP: + continue + key_map = QuantConstants.DEQUANT_OPS_KEY_MAP[dq.target] + requantize_attrs = self.get_quant_attrs(q, key_map) + if not all( + _check_same(ori_quant_attrs[key], requantize_attrs[key]) + for key in key_map.values() + ): + requantize_map[idx] = requantize_attrs + + def _annotate(self, graph_module: GraphModule): + for node in graph_module.graph.nodes: + key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None) + if not key_map: + continue + source_node = node.args[0] + if source_node.target in ( + *QuantConstants.QUANT_OPS_KEY_MAP, + *QuantConstants.DEQUANT_OPS_KEY_MAP, + ): + # Currently, don't add quant info for d_qd node here. + continue + elif source_node.target == operator.getitem: + source_node = source_node.args[0] + quant_attrs = self.get_quant_attrs(node, key_map) + source_node.meta["quantize_attrs"] = quant_attrs + self._annotate_requantize(source_node) + self._propagate_quant_params(source_node) + + def call(self, graph_module: GraphModule): + self._annotate(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) + + def get_quant_attrs( + self, quant_node: torch.fx.Node, key_map: Optional[Dict] = None + ) -> Dict[str, Any]: + quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments] + quant_attrs = dict.fromkeys(quant_attr_keys) + for key, attr in zip(quant_attr_keys[1:], quant_node.args[1:]): + # For channel-wise quantization, params are stored by buffer nodes. + if isinstance(attr, torch.fx.Node): + attr = get_buffer(self.edge_program, attr) + quant_attrs[key] = attr + quant_attrs["target"] = quant_node.target + if key_map is None: + return quant_attrs + miss_attrs = [] + for aten_attr, snc_attr in key_map.items(): + if aten_attr not in quant_attrs: + miss_attrs.append(aten_attr) + continue + attr = quant_attrs[aten_attr] + quant_attrs.pop(aten_attr) + quant_attrs[snc_attr] = attr + assert ( + not miss_attrs + ), f"Miss quant attrs {miss_attrs} for node {quant_node.name}" + return quant_attrs diff --git a/backends/samsung/_passes/annotate_scalar_parameters.py b/backends/samsung/_passes/annotate_scalar_parameters.py new file mode 100644 index 00000000000..643685bdb25 --- /dev/null +++ b/backends/samsung/_passes/annotate_scalar_parameters.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from executorch.backends.samsung.quantizer.quantizer import global_quant_info +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.backends.transforms.utils import get_param_tensor, is_param_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.export import ExportedProgram + + +class AnnotateScalarParametersPass(ExportPass): + """ + Need to add quantization parameters for scalars for some ops + Ifm(Quantized)------TargetOP--- + Scalar(Non-Quant)---/ + Notice: Such scalars are converted to tensor node by default pass + """ + + TARGET_OPS = { + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.div.Tensor, + } + + def __init__(self, edge_program: ExportedProgram): + super().__init__() + self.edge_program = edge_program + + def annotate(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta: + continue + torch_quant_dtype = global_quant_info.weight_precison.torch_dtype + for input_arg in node.all_input_nodes: + if input_arg.op not in ("placeholder", "get_attr") or not is_param_node( + self.edge_program, input_arg + ): + continue + else: + tensor = get_param_tensor(self.edge_program, input_arg) + if not tensor.shape: + qparams = { + QuantConstants.QUANT_KEY.scale: float(tensor), + QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype, + QuantConstants.QUANT_KEY.quant_max: torch.iinfo( + torch_quant_dtype + ).max, + QuantConstants.QUANT_KEY.quant_min: torch.iinfo( + torch_quant_dtype + ).min, + QuantConstants.QUANT_KEY.zero_point: 0, + } + input_arg.meta["quantize_attrs"] = qparams + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + self.annotate(graph_module) + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/conv1d_to_conv2d.py b/backends/samsung/_passes/conv1d_to_conv2d.py index 57f1074b348..1b8782d956b 100644 --- a/backends/samsung/_passes/conv1d_to_conv2d.py +++ b/backends/samsung/_passes/conv1d_to_conv2d.py @@ -5,84 +5,93 @@ # LICENSE file in the root directory of this source tree. 
import torch +from executorch.backends.transforms.utils import get_param_tensor from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from torch._export.utils import get_param class Conv1dToConv2d(ExportPass): - def __init__(self, edge_program: ExportedProgram): super().__init__() self.edge_program = edge_program + def update_kernel(self, weight_node: torch.Tensor): + # lifted tensor in tensor constant + weight_3d = get_param_tensor(self.edge_program, weight_node) + if param_name := self.edge_program.graph_signature.inputs_to_parameters.get( + weight_node.name + ): + new_weight_param = torch.nn.Parameter( + data=weight_3d.data.contiguous().unsqueeze(dim=-1), requires_grad=False + ) + self.edge_program.state_dict[param_name] = new_weight_param + elif tensor_name := self.edge_program.graph_signature.inputs_to_lifted_tensor_constants.get( + weight_node.name + ): + self.edge_program.constants[tensor_name] = torch.unsqueeze(weight_3d, -1) + else: + RuntimeError("Weight of 1d conv should be constant tensor or Parameter obj") + weight_node.meta["val"] = weight_node.meta["val"].data.unsqueeze(dim=-1) + def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph node_list = list(graph.nodes) for node in node_list: - if node.op == "call_function": - if node.target == exir_ops.edge.aten.convolution.default: - stride = list(node.args[3]) - if len(stride) != 1: - continue + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.convolution.default: + continue + stride = list(node.args[3]) + if len(stride) != 1: + continue - # convert 3dim weight to 4dim - weight_node = node.args[1] - weight_3dim = get_param(self.edge_program, weight_node) - weight_4dim = torch.nn.Parameter( - data=weight_3dim.data.contiguous().unsqueeze(dim=-1), - requires_grad=False, - ) - parameter_name = ( - self.edge_program.graph_signature.inputs_to_parameters[ - weight_node.name - ] - ) - self.edge_program.state_dict[parameter_name] = weight_4dim - weight_node.meta["val"] = weight_node.meta["val"].data.unsqueeze( - dim=-1 - ) + # convert 3dim weight to 4dim + weight_node = node.args[1] + self.update_kernel(weight_node) - # Extend stride, padding, and dilation - node.args = ( - node.args[0], - node.args[1], - node.args[2], - node.args[3] + [1], # stride - node.args[4] + [0], # padding - node.args[5] + [1], # dilation - node.args[6], - node.args[7], - node.args[8], - ) + # Extend stride, padding, and dilation + node.args = ( + node.args[0], + node.args[1], + node.args[2], + node.args[3] + [1], # stride + node.args[4] + [0], # padding + node.args[5] + [1], # dilation + node.args[6], + node.args[7], + node.args[8], + ) + # unsqueeze -> conv2d -> squeeze - # unsqueeze -> conv2d -> squeeze - with graph.inserting_before(node): - input_node = node.args[0] - unsqueeze_before = graph.create_node( - "call_function", exir_ops.edge.aten.unsqueeze_copy.default - ) - unsqueeze_before.args = ( - input_node, - -1, - ) - node.replace_input_with(input_node, unsqueeze_before) + with graph.inserting_before(node): + input_node = node.args[0] + prev_qparams = input_node.meta.get("quantize_attrs") + unsqueeze_before = graph.create_node( + "call_function", exir_ops.edge.aten.unsqueeze_copy.default + ) + unsqueeze_before.args = ( + input_node, + -1, + ) + node.replace_input_with(input_node, unsqueeze_before) - with graph.inserting_after(node): - squeeze_after = graph.create_node( - "call_function", 
exir_ops.edge.aten.squeeze_copy.dims - ) - squeeze_after.args = ( - node, - [-1], - ) - original_users = [ - user for user in node.users if user != squeeze_after - ] - for user in original_users: - user.replace_input_with(node, squeeze_after) + with graph.inserting_after(node): + squeeze_after = graph.create_node( + "call_function", exir_ops.edge.aten.squeeze_copy.dims + ) + squeeze_after.args = ( + node, + [-1], + ) + original_users = [user for user in node.users if user != squeeze_after] + for user in original_users: + user.replace_input_with(node, squeeze_after) + if quant_attr := node.meta.get("quantize_attrs"): + squeeze_after.meta["quantize_attrs"] = quant_attr + if prev_qparams is not None: + unsqueeze_before.meta["quantize_attrs"] = prev_qparams graph_module.recompile() - graph_module = super().call(graph_module).graph_module + _ = super().call(graph_module).graph_module return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/fold_qdq.py b/backends/samsung/_passes/fold_qdq.py new file mode 100644 index 00000000000..c6f3699ece7 --- /dev/null +++ b/backends/samsung/_passes/fold_qdq.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass +from torch.fx import GraphModule + + +class FoldQDQPass(ExportPass): + def __init__(self): + super().__init__() + + def _fold( + self, + graph_module: GraphModule, + ): + for node in graph_module.graph.nodes: + if node.target not in ( + *QuantConstants.QUANT_OPS_KEY_MAP.keys(), + *QuantConstants.DEQUANT_OPS_KEY_MAP.keys(), + ): + continue + for user in [user for user in node.users.keys()]: # noqa: C416 + user.replace_input_with(node, node.args[0]) + graph_module.graph.erase_node(node) + + def call(self, graph_module: GraphModule): + self._fold(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + _ = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/fuse_conv_act.py b/backends/samsung/_passes/fuse_conv_act.py new file mode 100644 index 00000000000..c034c98bb14 --- /dev/null +++ b/backends/samsung/_passes/fuse_conv_act.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass +from torch.fx import GraphModule + + +def map_hardtan_relux(tanhnode: torch.fx.node.Node) -> Optional[str]: + assert ( + tanhnode.target == exir_ops.edge.aten.hardtanh.default + ), "Must be a hardtanh node" + if not tanhnode.args[1] == 0.0: + return None + if tanhnode.args[2] == 6.0: + return "RELU6" + return None + + +class FuseConvActPass(ExportPass): + TARGET_ACTS_MAP = { + exir_ops.edge.aten.relu.default: (lambda x: "RELU"), + exir_ops.edge.aten.relu_.default: (lambda x: "RELU"), + exir_ops.edge.aten.relu6.default: (lambda x: "RELU6"), + exir_ops.edge.aten.relu6_.default: (lambda x: "RELU6"), + exir_ops.edge.aten.hardtanh.default: map_hardtan_relux, + exir_ops.edge.aten.hardtanh_.default: map_hardtan_relux, + } + + def _fuse( + self, + graph_module: GraphModule, + ): + for target_conv, target_act in self.get_target_conv_act(graph_module): + assert ( + act_name := self.TARGET_ACTS_MAP.get(target_act.target)(target_act) + ), f"Not supported {target_act.name} now." + target_conv.meta["activation"] = act_name + if "quantize_attrs" in target_act.meta: + target_conv.meta["quantize_attrs"] = target_act.meta["quantize_attrs"] + + # If we merge the real out activation to conv, the conv should be the real out + if "real_out" in target_act.meta: + target_conv.meta["real_out"] = target_act.meta["real_out"] + for user in [user for user in target_act.users.keys()]: # noqa: C416 + user.replace_input_with(target_act, target_conv) + graph_module.graph.erase_node(target_act) + + def get_target_conv_act(self, graph_module: GraphModule): + for node in graph_module.graph.nodes: + if node.target != exir_ops.edge.aten.convolution.default: + continue + if len(node.users) != 1: + # Such cases couldn't be conv + act + continue + act_node = list(node.users.keys())[0] + if act_node.target not in self.TARGET_ACTS_MAP: + continue + if "quantize_attrs" in node.meta: + # If the conv's output is quantized + # We do not fuse them + continue + yield node, act_node + + def call(self, graph_module: GraphModule): + self._fuse(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + _ = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/insert_qdq.py b/backends/samsung/_passes/insert_qdq.py new file mode 100644 index 00000000000..a59b011ac4b --- /dev/null +++ b/backends/samsung/_passes/insert_qdq.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from enum import Enum +from typing import Any, Dict + +import torch +from executorch.backends.samsung._passes.utils import none_quant_tensor_quant_meta +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.backends.samsung.utils.utils import is_graph_input, is_graph_output + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.export import ExportedProgram +from torch.fx import GraphModule + + +class QType(Enum): + Quant = 0 + Dequant = 1 + + +class InsertQDQPass(ExportPass): + QDQ_MAP = { + # per tensor + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + # per channel + exir_ops.edge.quantized_decomposed.quantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + } + + def __init__(self, edge_program: ExportedProgram): + super().__init__() + self.edge_program = edge_program + + def _create_qdq_node( + self, + graph_module: GraphModule, + qtype: QType, + input_node: torch.fx.Node, + quant_attrs: Dict[str, Any], + ) -> torch.fx.Node: + assert (target := quant_attrs.get("target")), "" + new_node_args = [input_node] + new_node_meta_val = input_node.meta["val"] + new_node_quant_attrs = {} + if qtype == QType.Dequant: + target = self.QDQ_MAP[target] + else: + # For input node, we should set the val type as quant type + key = QuantConstants.QUANT_KEY.quant_dtype + new_node_meta_val = new_node_meta_val.to(quant_attrs[key]) + new_node_quant_attrs.update(quant_attrs) + + for arg in target._schema.arguments[1:]: + name = arg.name + if name == "out_dtype": + continue + if qtype == QType.Quant: + key = QuantConstants.QUANT_OPS_KEY_MAP[target].get(name, name) + else: + key = QuantConstants.DEQUANT_OPS_KEY_MAP[target].get(name, name) + arg_value = quant_attrs[key] + if isinstance(arg.type, torch.Tensor) and ( + isinstance(arg_value, int) or isinstance(arg_value, float) + ): + arg_value = torch.Tensor(arg_value) + new_node_args.append(arg_value) + + new_node = graph_module.graph.create_node( + "call_function", target, tuple(new_node_args) + ) + if new_node_quant_attrs: + new_node.meta["quantize_attrs"] = new_node_quant_attrs + else: + new_node.meta["quantize_attrs"] = { + QuantConstants.QUANT_KEY.quant_dtype: torch.float32, + QuantConstants.QUANT_KEY.scale: [1.0], + QuantConstants.QUANT_KEY.zero_point: [0], + } + new_node.meta["val"] = new_node_meta_val + return new_node + + def _add_dq_after(self, graph_module: GraphModule, node: torch.fx.Node): + if not (quant_attrs := node.meta.get("quantize_attrs")): + return + with graph_module.graph.inserting_after(node): + new_node = self._create_qdq_node( + graph_module, QType.Dequant, node, quant_attrs + ) + users = [user for user in node.users.keys() if (user.op == "output")] + for user in users: + user.replace_input_with(node, new_node) + + def _add_q_after(self, graph_module: GraphModule, node: torch.fx.Node): + # In node don't need quant attrs after insert new quantize node. 
+ if not (quant_attrs := node.meta.pop("quantize_attrs", None)): + return + node.meta["quantize_attrs"] = none_quant_tensor_quant_meta() + with graph_module.graph.inserting_after(node): + users = list(node.users.keys()) + new_node = self._create_qdq_node( + graph_module, QType.Quant, node, quant_attrs + ) + for user in users: + if user.target not in QuantConstants.QUANT_OPS_KEY_MAP: + user.replace_input_with(node, new_node) + + def _add_q_before( + self, + graph_module: GraphModule, + node: torch.fx.Node, + from_node: torch.fx.Node, + quantize_attrs: Dict, + ): + with graph_module.graph.inserting_before(node): + new_quant_node = self._create_qdq_node( + graph_module, QType.Quant, from_node, quantize_attrs + ) + node.replace_input_with(from_node, new_quant_node) + return new_quant_node + + def _add_dq_before( + self, + graph_module: GraphModule, + node: torch.fx.Node, + from_node: torch.fx.Node, + quantize_attrs: Dict, + ): + with graph_module.graph.inserting_before(node): + new_dequant_node = self._create_qdq_node( + graph_module, QType.Dequant, from_node, quantize_attrs + ) + node.replace_input_with(from_node, new_dequant_node) + return new_dequant_node + + def _add_qdq_for_requantize(self, graph_module: GraphModule): + for node in graph_module.graph.nodes: + requant_map: Dict[int, Dict] = node.meta.get("requantize") + if requant_map is None: + continue + assert (ori_quant_attrs := node.meta.get("quantize_attrs")) + usr_list = list(node.users.keys()) + for user_idx, requant_params in requant_map.items(): + user = usr_list[user_idx] + q_node = self._add_q_before(graph_module, user, node, requant_params) + _ = self._add_dq_before(graph_module, q_node, node, ori_quant_attrs) + + def _add_qdq(self, graph_module: GraphModule): + for node in list(graph_module.graph.nodes): + if is_graph_input(self.edge_program, node): + self._add_q_after(graph_module, node) + elif is_graph_output(node): + self._add_dq_after(graph_module, node) + + def call(self, graph_module: GraphModule): + self._add_qdq(graph_module) + self._add_qdq_for_requantize(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/remove_useless_ops.py b/backends/samsung/_passes/remove_useless_ops.py new file mode 100644 index 00000000000..c88a2d4a5d8 --- /dev/null +++ b/backends/samsung/_passes/remove_useless_ops.py @@ -0,0 +1,87 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass +from torch.fx import GraphModule + + +class RemoveUselessOpPass(ExportPass): + # such ops should be single-in and single-out + USELESS_OP_SET = { + exir_ops.edge.aten._to_copy.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.alias.default, + exir_ops.edge.aten.lift_fresh_copy.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + } + + def __init__(self): + super().__init__() + + def gen_pattern_as_strided_copy(self, graph_module: GraphModule): + for node in list(graph_module.graph.nodes): # noqa: C416 + if node.target != exir_ops.edge.aten.mean.dim: + continue + if len(node.users) != 1: + continue + successor = list(node.users.keys())[0] + if successor.target != exir_ops.edge.aten.as_strided_copy.default: + continue + is_pattern = True + count = 0 + for i, stride in enumerate(successor.args[2]): + if stride < node.meta["val"].size()[i]: + if stride == 1: + count += 1 + else: + is_pattern = False + break + if count >= 2: + is_pattern = False + break + if is_pattern: + yield successor + + def _fold_as_strided_copy( + self, + graph_module: GraphModule, + ): + for as_strided_copy_node in self.gen_pattern_as_strided_copy(graph_module): + for user in list(as_strided_copy_node.users.keys()): + user.replace_input_with( + as_strided_copy_node, as_strided_copy_node.args[0] + ) + graph_module.graph.erase_node(as_strided_copy_node) + + def _remove_useless( + self, + graph_module: GraphModule, + ): + for node in graph_module.graph.nodes: + if node.target not in self.USELESS_OP_SET: + continue + + # Prevent from removing if data type may change. + if ( + node.target == exir_ops.edge.aten._to_copy.default + or node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ) and "memory_format" not in node.kwargs: + continue + + for user in [user for user in node.users.keys()]: # noqa: C416 + user.replace_input_with(node, node.all_input_nodes[0]) + graph_module.graph.erase_node(node) + self._fold_as_strided_copy(graph_module) + + def call(self, graph_module: GraphModule): + self._remove_useless(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + _ = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/utils.py b/backends/samsung/_passes/utils.py new file mode 100644 index 00000000000..afa7c72c601 --- /dev/null +++ b/backends/samsung/_passes/utils.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + + +def none_quant_tensor_quant_meta(): + return { + "quant_dtype": torch.float32, + "scales": 1, + "zero_points": 0, + } diff --git a/backends/samsung/builders/__init__.py b/backends/samsung/builders/__init__.py index 02a457fd06e..978da82b370 100644 --- a/backends/samsung/builders/__init__.py +++ b/backends/samsung/builders/__init__.py @@ -14,11 +14,13 @@ op_clamp, op_constant_pad_nd, op_conv2d, + op_dequantize, op_div, op_embedding, op_expand_copy, op_gelu, op_getitem, + op_hardsigmoid, op_hardswish, op_hardtanh, op_layer_norm, @@ -32,6 +34,7 @@ op_mul, op_permute, op_pixel_shuffle, + op_quantize, op_relu, op_reshape, op_rsqrt, @@ -57,6 +60,7 @@ op_clamp, op_conv2d, op_constant_pad_nd, + op_dequantize, op_div, op_embedding, op_expand_copy, @@ -64,6 +68,7 @@ op_getitem, op_hardswish, op_hardtanh, + op_hardsigmoid, op_layer_norm, op_leaky_relu, op_linear, @@ -75,6 +80,7 @@ op_mul, op_permute, op_pixel_shuffle, + op_quantize, op_relu, op_reshape, op_rsqrt, diff --git a/backends/samsung/builders/node_visitor.py b/backends/samsung/builders/node_visitor.py index a35c0b4715d..0d2707da8f5 100644 --- a/backends/samsung/builders/node_visitor.py +++ b/backends/samsung/builders/node_visitor.py @@ -14,6 +14,7 @@ get_tensor_type, ) from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph +from executorch.backends.samsung.utils.constants import QuantConstants from executorch.backends.transforms.utils import is_param_node from torch.export import ExportedProgram @@ -61,18 +62,26 @@ def define_tensor( dims = [1] if len(tensor.size()) == 0 else list(tensor.size()) + quant_attrs = node.meta.get("quantize_attrs") enn_tensor_id = enn_graph.define_tensor( node.name, dims, data_type, tensor_type.name, const_data, + quant_param=quant_attrs, ) assert enn_tensor_id is not None vals_to_ids[node] = enn_tensor_id return enn_tensor_id + def _update_params_qdtype(self, node: torch.fx.Node, params: Dict): + if qdtype := node.meta.get("quantize_attrs", {}).get( + QuantConstants.QUANT_KEY.quant_dtype + ): + params["quant_dtype"] = EnnGraph._affine_meta_param(qdtype) + _node_visitor_dict = {} @@ -92,6 +101,7 @@ def register_node_visitor(visitor): raise TypeError( f"target of vistor should be str|Tuple[str]|List[str], not{type(visitor.target)}" ) + return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: diff --git a/backends/samsung/builders/op_add.py b/backends/samsung/builders/op_add.py index 1b0dddb0d02..a6eb79897dd 100644 --- a/backends/samsung/builders/op_add.py +++ b/backends/samsung/builders/op_add.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ from typing import Dict import torch @@ -28,9 +29,13 @@ def define_node( ) -> None: input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "ELTSUM", [input_id_1, input_id_2], [output_id]) + enn_graph.define_op( + node.name, "ELTSUM", [input_id_1, input_id_2], [output_id], params + ) diff --git a/backends/samsung/builders/op_avg_pool2d.py b/backends/samsung/builders/op_avg_pool2d.py index ad7ccbac3ae..bfca8b89b22 100644 --- a/backends/samsung/builders/op_avg_pool2d.py +++ b/backends/samsung/builders/op_avg_pool2d.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict, List import torch @@ -49,6 +50,7 @@ def define_node( params["stride_w"] = stride[1] params["padding"] = "EXPLICIT" params["explicit_padding"] = explicit_padding + self._update_params_qdtype(node, params) if len(node.args) > 4: ceil_mode = cast(bool, node.args[4]) @@ -64,7 +66,5 @@ def define_node( assert ( divisor_override == kernel_size[0] * kernel_size[1] ), "Not supported divisor_override which is not equal to pooling region." - output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "AVGPOOL2D", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_bmm.py b/backends/samsung/builders/op_bmm.py index 6ba8864ebb3..13e0d19cb14 100644 --- a/backends/samsung/builders/op_bmm.py +++ b/backends/samsung/builders/op_bmm.py @@ -16,7 +16,7 @@ @register_node_visitor class BMMVisitor(NodeVisitor): - target = "aten.bmm.default" + target = ["aten.bmm.default"] def __init__(self, *args) -> None: super().__init__(*args) @@ -29,12 +29,15 @@ def define_node( ) -> None: input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) enn_graph.define_op( - node.name, "BATCH_MATMUL", [input_id_1, input_id_2], [output_id] + node.name, "BATCH_MATMUL", [input_id_1, input_id_2], [output_id], params ) diff --git a/backends/samsung/builders/op_cat.py b/backends/samsung/builders/op_cat.py index e9c0a32b389..09387f2e361 100644 --- a/backends/samsung/builders/op_cat.py +++ b/backends/samsung/builders/op_cat.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ from typing import cast, Dict, List import torch @@ -12,6 +13,7 @@ ) from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph from executorch.backends.transforms import get_shape +from executorch.backends.transforms.utils import is_param_node @register_node_visitor @@ -29,14 +31,20 @@ def define_node( ) -> None: tensors = cast(List[torch.fx.Node], node.args[0]) input_tensor_ids = [] - - for in_tensor in tensors: + constant_idx = None + for idx, in_tensor in enumerate(tensors): + if is_param_node(self.exported_program, in_tensor): + assert constant_idx is None, "Only support at most 1 constant tensor" + constant_idx = idx input_id = self.define_tensor(in_tensor, enn_graph, vals_to_ids) input_tensor_ids.append(input_id) in_shape = get_shape(node) axis = cast(int, node.args[1]) % len(in_shape) if len(node.args) >= 2 else 0 params = {"axis": axis} + if constant_idx is not None: + params["constant_index"] = constant_idx + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) enn_graph.define_op(node.name, "CONCAT", input_tensor_ids, [output_id], params) diff --git a/backends/samsung/builders/op_clamp.py b/backends/samsung/builders/op_clamp.py index c5670b80fa3..74af83212a5 100644 --- a/backends/samsung/builders/op_clamp.py +++ b/backends/samsung/builders/op_clamp.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict import torch @@ -32,12 +33,15 @@ def define_node( # The default value of lower bound and upper bound output_min = torch.finfo(torch.float32).min output_max = torch.finfo(torch.float32).max + if node.args[1] is not None: output_min = cast(float, node.args[1]) if len(node.args) > 2 and node.args[2] is not None: output_max = cast(float, node.args[2]) params = {"minimum": output_min, "maximum": output_max} + self._update_params_qdtype(node, params) + output_id = self.define_tensor(node, enn_graph, vals_to_ids) enn_graph.define_op(node.name, "CLIP", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_conv2d.py b/backends/samsung/builders/op_conv2d.py index 881a533801f..ab77d8df626 100644 --- a/backends/samsung/builders/op_conv2d.py +++ b/backends/samsung/builders/op_conv2d.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict, List import torch @@ -56,6 +57,9 @@ def define_node( input_shape = get_shape(input) kernel_shape = get_shape(weight_node) params = {} + self._update_params_qdtype(node, params) + if "activation" in node.meta: + params["activation"] = node.meta["activation"] params["kernel_h"] = kernel_shape[2] params["kernel_w"] = kernel_shape[3] params["stride_h"] = stride[0] diff --git a/backends/samsung/builders/op_dequantize.py b/backends/samsung/builders/op_dequantize.py new file mode 100644 index 00000000000..a1c31af4037 --- /dev/null +++ b/backends/samsung/builders/op_dequantize.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from executorch.backends.samsung.builders.node_visitor import register_node_visitor +from executorch.backends.samsung.builders.op_quantize import _QuantOpVistorBase + + +# Dequant ops here +@register_node_visitor +class DequantizeVistor(_QuantOpVistorBase): + target = [ + "quantized_decomposed.dequantize_per_tensor.default", + "quantized_decomposed.dequantize_per_tensor.tensor", + "quantized_decomposed.dequantize_per_channel.default", + "quantized_decomposed.dequantize_per_channel.tensor", + ] diff --git a/backends/samsung/builders/op_div.py b/backends/samsung/builders/op_div.py index 89d773ddb0e..8b0e7cdd5af 100644 --- a/backends/samsung/builders/op_div.py +++ b/backends/samsung/builders/op_div.py @@ -27,13 +27,16 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: - # inputs input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) - + params = {} + self._update_params_qdtype(node, params) # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "ELTDIV", [input_id_1, input_id_2], [output_id]) + enn_graph.define_op( + node.name, "ELTDIV", [input_id_1, input_id_2], [output_id], params + ) diff --git a/backends/samsung/builders/op_gelu.py b/backends/samsung/builders/op_gelu.py index 059a3b77850..88417f688f9 100644 --- a/backends/samsung/builders/op_gelu.py +++ b/backends/samsung/builders/op_gelu.py @@ -27,8 +27,14 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: - input_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids) + # input1 + input = node.args[0] + input_id = self.define_tensor(input, enn_graph, vals_to_ids) + # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "GELU", [input_id], [output_id]) + params = {} + self._update_params_qdtype(node, params) + + enn_graph.define_op(node.name, "GELU", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_hardsigmoid.py b/backends/samsung/builders/op_hardsigmoid.py new file mode 100644 index 00000000000..3a50d65da41 --- /dev/null +++ b/backends/samsung/builders/op_hardsigmoid.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict + +import torch +from executorch.backends.samsung.builders.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph + + +@register_node_visitor +class HardSigmoidVisitor(NodeVisitor): + target = "aten.hardsigmoid.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + enn_graph: EnnGraph, + vals_to_ids: Dict[torch.Tensor, int], + ) -> None: + input = node.args[0] + input_id = self.define_tensor(input, enn_graph, vals_to_ids) + output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) + enn_graph.define_op(node.name, "HardSigmoid", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_hardswish.py b/backends/samsung/builders/op_hardswish.py index 72a99d17b83..8c30125e8a4 100644 --- a/backends/samsung/builders/op_hardswish.py +++ b/backends/samsung/builders/op_hardswish.py @@ -29,7 +29,7 @@ def define_node( ) -> None: input = node.args[0] input_id = self.define_tensor(input, enn_graph, vals_to_ids) - + params = {} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - - enn_graph.define_op(node.name, "HARDSWISH", [input_id], [output_id]) + enn_graph.define_op(node.name, "HARDSWISH", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_hardtanh.py b/backends/samsung/builders/op_hardtanh.py index 4f667bf5299..7d65e97a566 100644 --- a/backends/samsung/builders/op_hardtanh.py +++ b/backends/samsung/builders/op_hardtanh.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict import torch @@ -29,9 +30,12 @@ def define_node( input = node.args[0] input_id = self.define_tensor(input, enn_graph, vals_to_ids) + # default value of output_min and output_max output_min = cast(float, node.args[1]) if len(node.args) > 1 else -1 output_max = cast(float, node.args[2]) if len(node.args) > 2 else 1 + params = {"minimum": output_min, "maximum": output_max} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) diff --git a/backends/samsung/builders/op_layer_norm.py b/backends/samsung/builders/op_layer_norm.py index e6f853178d8..098bc92dc84 100644 --- a/backends/samsung/builders/op_layer_norm.py +++ b/backends/samsung/builders/op_layer_norm.py @@ -46,9 +46,8 @@ def define_node( epsilon = node.args[4] if len(node.args) > 4 else 1e-5 params = {"epsilon": epsilon} - + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op( node.name, "LAYERNORM", all_input_tensors, [output_id], params ) diff --git a/backends/samsung/builders/op_linear.py b/backends/samsung/builders/op_linear.py index 2f7aa1e6415..720439de976 100644 --- a/backends/samsung/builders/op_linear.py +++ b/backends/samsung/builders/op_linear.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ from typing import Dict import torch @@ -43,6 +44,7 @@ def define_node( weight_shape = get_shape(weight_node) params = {"in_channels": weight_shape[1], "out_channels": weight_shape[0]} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) diff --git a/backends/samsung/builders/op_max_pool2d.py b/backends/samsung/builders/op_max_pool2d.py index d386dd30b1a..57b716fcb34 100644 --- a/backends/samsung/builders/op_max_pool2d.py +++ b/backends/samsung/builders/op_max_pool2d.py @@ -73,6 +73,7 @@ def define_node( params["explicit_padding"] = explicit_padding params["dilation_h"] = dilation[0] params["dilation_w"] = dilation[1] + self._update_params_qdtype(node, params) if len(node.args) > 5: ceil_mode = cast(bool, node.args[5]) diff --git a/backends/samsung/builders/op_mean_dim.py b/backends/samsung/builders/op_mean_dim.py index 2f07f870ec4..3d0377703a7 100644 --- a/backends/samsung/builders/op_mean_dim.py +++ b/backends/samsung/builders/op_mean_dim.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict, List import torch @@ -27,6 +28,7 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: + # input input = node.args[0] input_id = self.define_tensor(input, enn_graph, vals_to_ids) @@ -37,8 +39,11 @@ def define_node( in_shape = get_shape(input) for dim in dims: reduce_axes.append(dim % len(in_shape)) - reduce_axes.sort() + + if len(node.args[1]) > 1: + reduce_axes.sort() keep_dim = node.args[2] if len(node.args) >= 3 else False params = {"keep_dims": keep_dim, "axis": reduce_axes} + self._update_params_qdtype(node, params) enn_graph.define_op(node.name, "REDUCEMEAN", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_mul.py b/backends/samsung/builders/op_mul.py index dce531ff0b0..6dd7c0dd9f0 100644 --- a/backends/samsung/builders/op_mul.py +++ b/backends/samsung/builders/op_mul.py @@ -1,5 +1,9 @@ -# Copyright (c) 2024 Samsung Electronics Co. LTD +# Copyright (c) 2025 Samsung Electronics Co. LTD # All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from typing import Dict import torch @@ -23,11 +27,17 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: + input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "ELTMUL", [input_id_1, input_id_2], [output_id]) + enn_graph.define_op( + node.name, "ELTMUL", [input_id_1, input_id_2], [output_id], params + ) diff --git a/backends/samsung/builders/op_quantize.py b/backends/samsung/builders/op_quantize.py new file mode 100644 index 00000000000..dcf30e291f9 --- /dev/null +++ b/backends/samsung/builders/op_quantize.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict + +import torch +from executorch.backends.samsung.builders.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph +from executorch.backends.samsung.utils.constants import QuantConstants + + +class _QuantOpVistorBase(NodeVisitor): + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + enn_graph: EnnGraph, + vals_to_ids: Dict[torch.Tensor, int], + ) -> None: + # input + input = node.args[0] + input_id = self.define_tensor(input, enn_graph, vals_to_ids) + + scales = node.args[1] + if isinstance(scales, torch.Tensor): + scales = scales.tolist() + elif not isinstance(scales, list): + scales = torch.tensor(scales).reshape([1]).tolist() + zero_points = node.args[2] + if isinstance(zero_points, torch.Tensor): + zero_points = zero_points.tolist() + elif not isinstance(zero_points, list): + zero_points = torch.tensor(zero_points).reshape([1]).tolist() + + output_id = self.define_tensor(node, enn_graph, vals_to_ids) + + params = {"scales": scales, "zero_points": zero_points} + + if node.target in QuantConstants.QUANT_OPS_KEY_MAP: + enn_graph.define_op(node.name, "QUANTIZE", [input_id], [output_id], params) + else: + enn_graph.define_op( + node.name, "DEQUANTIZE", [input_id], [output_id], params + ) + + +@register_node_visitor +class QuantizeVistor(_QuantOpVistorBase): + target = [ + "quantized_decomposed.quantize_per_tensor.default", + "quantized_decomposed.quantize_per_channel.default", + ] diff --git a/backends/samsung/builders/op_relu.py b/backends/samsung/builders/op_relu.py index ba90116be1d..a4a2b6bc4f0 100644 --- a/backends/samsung/builders/op_relu.py +++ b/backends/samsung/builders/op_relu.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ from typing import Dict import torch @@ -30,5 +31,7 @@ def define_node( input_id = self.define_tensor(input, enn_graph, vals_to_ids) output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) - enn_graph.define_op(node.name, "RELU", [input_id], [output_id]) + enn_graph.define_op(node.name, "RELU", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_softmax.py b/backends/samsung/builders/op_softmax.py index 1e2e4a378dc..7f569cea6fc 100644 --- a/backends/samsung/builders/op_softmax.py +++ b/backends/samsung/builders/op_softmax.py @@ -35,5 +35,5 @@ def define_node( axis = cast(int, node.args[1]) params = {"axis": axis} - + self._update_params_qdtype(node, params) enn_graph.define_op(node.name, "SOFTMAX", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_squeeze.py b/backends/samsung/builders/op_squeeze.py index d165a22fcb3..82fa17fbc95 100644 --- a/backends/samsung/builders/op_squeeze.py +++ b/backends/samsung/builders/op_squeeze.py @@ -33,4 +33,5 @@ def define_node( # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id]) + params = {"new_shape": [*node.meta["val"].shape]} + enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_to_copy.py b/backends/samsung/builders/op_to_copy.py index 545672ef6a3..c770602bb5f 100644 --- a/backends/samsung/builders/op_to_copy.py +++ b/backends/samsung/builders/op_to_copy.py @@ -11,6 +11,8 @@ NodeVisitor, register_node_visitor, ) + +from executorch.backends.samsung.builders.utils import get_map_dtype, get_tensor from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph @@ -35,5 +37,8 @@ def define_node( input_id = self.define_tensor(input, enn_graph, vals_to_ids) output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + out_tensor = get_tensor(self.exported_program, node) + params["out_dtype"] = get_map_dtype(out_tensor.dtype) - enn_graph.define_op(node.name, "CAST", [input_id], [output_id]) + enn_graph.define_op(node.name, "CAST", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_unsqueeze.py b/backends/samsung/builders/op_unsqueeze.py index 942c3307de7..61fa06e6310 100644 --- a/backends/samsung/builders/op_unsqueeze.py +++ b/backends/samsung/builders/op_unsqueeze.py @@ -31,4 +31,5 @@ def define_node( output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id]) + params = {"new_shape": [*node.meta["val"].shape]} + enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_upsample_bilinear2d.py b/backends/samsung/builders/op_upsample_bilinear2d.py index a934b2789ba..d4b040460e3 100644 --- a/backends/samsung/builders/op_upsample_bilinear2d.py +++ b/backends/samsung/builders/op_upsample_bilinear2d.py @@ -46,6 +46,7 @@ def define_node( "upsampling_factor": scale_factor, "half_pixel_centers": True, } + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) enn_graph.define_op( node.name, "RESIZE_BILINEAR", [input_id], [output_id], params diff --git a/backends/samsung/builders/utils.py b/backends/samsung/builders/utils.py index 58c84ff6d31..a640071c798 100644 --- a/backends/samsung/builders/utils.py +++ b/backends/samsung/builders/utils.py @@ -9,7 +9,6 @@ import torch from 
executorch.backends.samsung.utils.utils import is_graph_input, is_graph_output from executorch.backends.transforms.utils import get_param_tensor, is_param_node - from torch.export import ExportedProgram DATA_TYPE_STR_MAPPING = { diff --git a/backends/samsung/enn_preprocess.py b/backends/samsung/enn_preprocess.py index dde01bc09c7..0847ec0adeb 100644 --- a/backends/samsung/enn_preprocess.py +++ b/backends/samsung/enn_preprocess.py @@ -9,10 +9,16 @@ import executorch.backends.samsung.python.PyEnnWrapperAdaptor as PyEnnWrapper import torch +from executorch.backends.samsung._passes.annotate_qparams import AnnotateQparamsPass +from executorch.backends.samsung._passes.annotate_scalar_parameters import ( + AnnotateScalarParametersPass, +) from executorch.backends.samsung._passes.conv1d_to_conv2d import Conv1dToConv2d from executorch.backends.samsung._passes.customized_constant_prop import ( ConstantPropPass, ) +from executorch.backends.samsung._passes.fold_qdq import FoldQDQPass +from executorch.backends.samsung._passes.insert_qdq import InsertQDQPass from executorch.backends.samsung._passes.replace_scalar_ops import ReplaceOpsWithScalar from executorch.backends.samsung.builders.node_visitor import get_node_visitors from executorch.backends.samsung.serialization.compile_options import ( @@ -53,12 +59,16 @@ def preprocess( enn_preprocess_passes = PassManager( passes=[ + AnnotateQparamsPass(edge_program), + FoldQDQPass(), ConstantPropPass(edge_program), Conv1dToConv2d(edge_program), FuseBatchNormWithConvPass(edge_program), AddmmToLinearTransform(), ReplaceOpsWithScalar(), RemoveGetItemPass(), + InsertQDQPass(edge_program), + AnnotateScalarParametersPass(edge_program), ] ) pass_result = enn_preprocess_passes(edge_program.graph_module) diff --git a/backends/samsung/partition/enn_partitioner.py b/backends/samsung/partition/enn_partitioner.py index 952cb000429..368d069c380 100644 --- a/backends/samsung/partition/enn_partitioner.py +++ b/backends/samsung/partition/enn_partitioner.py @@ -129,5 +129,6 @@ def ops_to_not_decompose( torch.ops.aten.prelu.default, torch.ops.aten.layer_norm.default, torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.hardsigmoid.default, ] return (ops_not_to_decompose, None) diff --git a/backends/samsung/quantizer/__init__.py b/backends/samsung/quantizer/__init__.py new file mode 100644 index 00000000000..621eec69240 --- /dev/null +++ b/backends/samsung/quantizer/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .qconfig import Precision +from .quantizer import EnnQuantizer + +__all__ = [EnnQuantizer, Precision] diff --git a/backends/samsung/quantizer/annotator.py b/backends/samsung/quantizer/annotator.py new file mode 100644 index 00000000000..31015698006 --- /dev/null +++ b/backends/samsung/quantizer/annotator.py @@ -0,0 +1,871 @@ +# Copyright (c) Qualcomm Innovation Center, Inc +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Dict, List + +import torch +from torch._ops import OpOverload +from torch._subclasses import FakeTensor + +from torch.fx import Graph, Node + +from torchao.quantization.pt2e import FixedQParamsObserver +from torchao.quantization.pt2e.quantizer import ( + annotate_output_qspec, + QuantizationAnnotation, + QuantizationSpec, + SharedQuantizationSpec, +) + +from .qconfig import QuantizationConfig + +OP_ANNOTATOR: Dict[OpOverload, Callable] = {} + +ADD_OPS = [ + torch.ops.aten.add, + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, +] + + +def register_annotator(ops: List[OpOverload]): + def decorator(annotator: Callable): + for op in ops: + OP_ANNOTATOR[op] = annotator + + return decorator + + +def annotate(graph: Graph, quant_config: QuantizationConfig) -> None: + # Pattern annotation + _annotate_fused_activation_pattern(graph, quant_config) + + # Per-op annotation + for node in graph.nodes: + if node.op == "placeholder": + annotate_placeholder(node, quant_config) + elif node.op == "call_function": + annotate_func = OP_ANNOTATOR.get(node.target, None) + if annotate_func is not None: + annotate_func(node, quant_config) + + +def _is_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _is_fake_tensor(node: Node): + if ( + isinstance(node, Node) + and "val" in node.meta + and isinstance(node.meta["val"], FakeTensor) + ): + return True + return False + + +def _is_float_tensor(node: Node): + """Check if the node's tensor is a float tensor, + so that we can skip quantization for the node + since observers only works with float Tensors + """ + if not _is_fake_tensor(node): + return False + return node.meta["val"].dtype in [torch.float32, torch.float16] + + +def _mark_nodes_as_annotated(nodes: List[Node]): + for node in nodes: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +# for nodes whose targets ars placehold (not call_function) +def annotate_placeholder(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + if _is_float_tensor(node): + annotate_output_qspec(node, quant_config.output_activation) + + _mark_nodes_as_annotated([node]) + + +# CASE 1: fused_activation case (ex. 
Conv2D + ReLU) +def _is_hardtanh_for_relux(relu_node: torch.fx.node.Node): + if relu_node.target in [ + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + ]: + # checking if hardtanh is convertable to ReLU6 + # ReLU1 is not supported now + if not relu_node.args[1] == 0.0: + return False + if relu_node.args[2] == 6.0: # for ReLU6 + return True + return True + + +def _annotate_fused_activation_pattern( + graph: Graph, quant_config: QuantizationConfig +) -> None: + for relu_node in graph.nodes: + # Check relu/relu6 node + if relu_node.op != "call_function": + continue + if relu_node.target not in [ + # The strategy of ReLU and ReLU6 is fold_activation in ENNQuant + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.relu6.default, + torch.ops.aten.relu6_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + ]: + continue + + if not _is_hardtanh_for_relux(relu_node): + continue + + producer_node = relu_node.args[0] + if not isinstance(producer_node, Node): + continue + if producer_node.op != "call_function": + continue + if len(producer_node.users) != 1: + continue + + # Handle affine + relu fusion + if producer_node.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ]: + # input & weight (or bias) setting for Conv node(producer_node) + quantization_annotation = producer_node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + input = producer_node.args[0] + quantization_annotation.input_qspec_map[input] = ( + quant_config.input_activation + ) + + quantization_annotation.input_qspec_map[producer_node.args[1]] = ( + quant_config.weight + ) + if len(producer_node.args) > 2 and quant_config.bias is not None: + quantization_annotation.input_qspec_map[producer_node.args[2]] = ( + quant_config.bias + ) + + producer_node.meta["quantization_annotation"] = quantization_annotation + producer_node.meta["quantization_annotation"]._annotated = True + # out setting for activation node (relu_node) + quantization_annotation = relu_node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + quantization_annotation.output_qspec = quant_config.output_activation + + relu_node.meta["quantization_annotation"] = quantization_annotation + relu_node.meta["quantization_annotation"]._annotated = True + continue + + +# CASE 2-1: two input case without Shared Quant +@register_annotator( + [ + torch.ops.aten.div, + torch.ops.aten.div.Tensor, + torch.ops.aten.divide.Tensor, + torch.ops.aten.matmul.default, + torch.ops.aten.bmm.default, + torch.ops.aten.sum.dim_IntList, + ] +) +def annotate_2in1out(node: Node, quant_config: QuantizationConfig) -> None: + input_act0 = node.args[0] + input_act1 = node.args[1] + # skipping quantization if 1st input is not float. 
+ if _is_annotated([node]) or not _is_float_tensor(input_act0): + return + + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + + +# getting QuantAnnot though the first input +def _get_quantization_annotation(node: Node): + if node.op == "placeholder": + return False + elif "quantization_annotation" in node.meta: + return node + elif node.args == (): + return False + elif isinstance(node.args[0], Node): + return _get_quantization_annotation(node.args[0]) + elif isinstance(node.args[0], list): + # for cat, concatenate and stack + if isinstance(node.args[0][0], Node): + return _get_quantization_annotation(node.args[0][0]) + else: + return False + else: + return False + + +# CASE 2-2: two input case with Shared Quant +# ops.add / ops.add_ are processed by another annotator +@register_annotator( + [ + torch.ops.aten.sub, + torch.ops.aten.mul, + torch.ops.aten.sub.Tensor, + torch.ops.aten.mul.Tensor, + torch.ops.aten.sub_.Tensor, + torch.ops.aten.mul_.Tensor, + torch.ops.aten.rsub.Scalar, + torch.ops.aten.mul.Scalar, + ] +) +def annotate_2in1out_with_SharedQuant( + node: Node, quant_config: QuantizationConfig +) -> None: + + input_qspec_map = {} + input0 = node.args[0] + input1 = node.args[1] + + # skipping quantization if 1st input is not float. + if _is_annotated([node]) or not _is_float_tensor(input0): + return + if ( + isinstance(input0, Node) + and isinstance(input1, float) + and not _get_quantization_annotation(input0) + ): + return + if ( + isinstance(input0, float) + and isinstance(input1, Node) + and not _get_quantization_annotation(input1) + ): + return + if isinstance(input0, Node) and isinstance(input1, Node): + shared_qspec = SharedQuantizationSpec((input0, node)) + input_qspec_map[input0] = quant_config.input_activation + input_qspec_map[input1] = shared_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + else: + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + input_act0 = node.args[0] + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = node.args[1] + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + + +# CASE 2-3: only for add ops +@register_annotator(ADD_OPS) +def annotate_add_ops_with_SharedQuant( + node: Node, quant_config: QuantizationConfig +) -> None: + + input_qspec_map = {} + input0 = node.args[0] + input1 = node.args[1] + + # skipping quantization if 1st input is not float. 
+ if _is_annotated([node]) or not _is_float_tensor(input0): + return + + if isinstance(input0, Node) and isinstance(input1, Node): + NonQuantShare_ops_for_add = [torch.ops.aten.dropout.default] + ADD_OPS + if ( + input0.op == "call_function" and input0.target in NonQuantShare_ops_for_add + ) or ( + input1.op == "call_function" and input1.target in NonQuantShare_ops_for_add + ): + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + input_act0 = node.args[0] + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = node.args[1] + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + else: + shared_qspec = SharedQuantizationSpec((input0, node)) + input_qspec_map[input0] = quant_config.input_activation + input_qspec_map[input1] = shared_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + elif ( + isinstance(input0, Node) + and isinstance(input1, float) + and not _get_quantization_annotation(input0) + ): + pass + elif ( + isinstance(input0, float) + and isinstance(input1, Node) + and not _get_quantization_annotation(input1) + ): + pass + else: + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + input_act0 = node.args[0] + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = node.args[1] + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + + +# CASE 3-1: Single input + Single Out case without Shared Quant +@register_annotator( + [ + torch.ops.aten.ceil.default, + torch.ops.aten.clamp.default, + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.relu6.default, + torch.ops.aten.relu6_.default, + torch.ops.aten.cos.default, + torch.ops.aten.sin.default, + torch.ops.aten.tanh.default, + torch.ops.aten.hardswish.default, + torch.ops.aten.hardswish_.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.mean.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.leaky_relu.default, + torch.ops.aten.leaky_relu_.default, + torch.ops.aten.prelu.default, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.mean.dim, + torch.ops.aten.sqrt.default, + torch.ops.aten.gelu.default, + torch.ops.aten.scaled_dot_product_attention.default, + torch.ops.aten.rsqrt.default, + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.topk.default, + ] +) +def annotate_1in1out(node: Node, quant_config: QuantizationConfig) -> None: + # skipping quantization if input is not float. 
+ if _is_annotated([node]) or not _is_float_tensor(node.args[0]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + # one inputs + one output case. + input_act_qspec = quant_config.input_activation + quantization_annotation.input_qspec_map[node.args[0]] = input_act_qspec + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +# CASE 3-2: Single input + Single Out case with Shared Quant +@register_annotator( + [ + torch.ops.aten.permute.default, + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.transpose.int, + torch.ops.aten.expand.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.reshape.default, + torch.ops.aten.select.int, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.pad.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.to.dtype, + ] +) +def annotate_1in1out_with_SharedQuant( + node: Node, quant_config: QuantizationConfig +) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + if _is_annotated([node]) or not _is_float_tensor(input): + return + + shared_qspec = SharedQuantizationSpec((input, node)) + + # get QuantAnnot from the input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + input_qspec_map[shared_quant_node] = SharedQuantizationSpec(shared_quant_node) + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + else: + # if no QuantAnnot in the input path + input_qspec_map[input] = quant_config.input_activation + shared_qspec = SharedQuantizationSpec((input, node)) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 3-3: Single input + Single Out case with FP +@register_annotator( + [ + torch.ops.aten.softmax.int, + torch.ops.aten._softmax.default, + torch.ops.aten._safe_softmax.default, + torch.ops.aten.log_softmax.int, + ] +) +def annotate_1in1out_with_SharedQuant_for_FP( + node: Node, quant_config: QuantizationConfig +) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + if input.target in ADD_OPS and _is_annotated([input]): + del input.meta["quantization_annotation"] + + # get QuantAnnot from the input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + # if QuantAnnot in the input path, input_qspec is shared, but output_qspec is not. 
+ input_qspec_map[shared_quant_node] = SharedQuantizationSpec(shared_quant_node) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quant_config.output_activation, + _annotated=True, + ) + else: + # if no QuantAnnot in the input path + node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=quant_config.output_activation, + _annotated=True, + ) + + +# CASE 4: One value input + one index input with Shared Quant +@register_annotator([torch.ops.aten.index.Tensor]) +def annotate_index(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + # get QuantAnnt from the input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + input_qspec_map[input] = quant_config.input_activation + + # sharing QuantAnnot with the parent + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 5 input + index + value & output with Shared Quant +@register_annotator( + [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default] +) +def annotate_index_put(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + input = node.args[0] # from KVCache in LLAMA + value = node.args[2] # from linear projection layer + assert isinstance(input, Node) + assert isinstance(value, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + # get QuantAnnot from input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + input_qspec_map[input] = shared_qspec + input_qspec_map[value] = shared_qspec + output_qspec = shared_qspec + else: + # if no QuantAnnot in input path, asign the default QuantAnnot from quant_config. + input_qspec_map[input] = quant_config.input_activation + input_qspec_map[value] = SharedQuantizationSpec((input, node)) + output_qspec = SharedQuantizationSpec((input, node)) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + + +# CASE 6 unbind + getitem case +# (inputQuant--unbinde--no Qunat) --> (no Qunat--getitem--outputQuant) +@register_annotator([torch.ops.aten.unbind.int]) +def annotate_unbind(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + # get QuantAnnot from input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + input_qspec_map[input] = quant_config.input_activation + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + else: + # if no QuantAnnot in input path, asign the default QuantAnnot from quant_config. 
+ input_qspec_map[input] = quant_config.input_activation + shared_qspec = SharedQuantizationSpec((input, node)) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + for users_node in node.users: + users_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 7: stand-alone Conv2d and Conv1d +@register_annotator( + [ + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.linear.default, + ] +) +def annotate_conv2d(node: Node, quant_config: QuantizationConfig) -> None: + # skipping quantization if weights are not float + if _is_annotated([node]) or not _is_float_tensor(node.args[1]): + return + + input = node.args[0] + # input & weight (or bias) setting for Conv node(producer_node) + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + quantization_annotation.input_qspec_map[input] = SharedQuantizationSpec( + shared_quant_node + ) + else: + quantization_annotation.input_qspec_map[input] = quant_config.input_activation + quantization_annotation.input_qspec_map[node.args[1]] = quant_config.weight + if len(node.args) > 2 and quant_config.bias is not None: + quantization_annotation.input_qspec_map[node.args[2]] = quant_config.bias + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +# CASE 8: embedding +@register_annotator([torch.ops.aten.embedding.default]) +def annotate_embedding(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + weight = node.args[0] + if _is_annotated([node]) or not _is_float_tensor(weight): + return + + input_qspec_map[weight] = quant_config.input_activation + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quant_config.output_activation, + _annotated=True, + ) + + +# CASE 9: Concat & Stack +@register_annotator( + [ + torch.ops.aten.cat.default, + torch.ops.aten.concat.default, + torch.ops.aten.stack.default, + ] +) +def annotate_cat(node: Node, quant_config: QuantizationConfig) -> None: + inputs = node.args[0] + first_input = inputs[0] + assert isinstance(inputs, list) + assert isinstance(first_input, Node) + + if _is_annotated([node]) or not _is_float_tensor(first_input): + return + + input_qspec_map = {} + shared_qspec = SharedQuantizationSpec((first_input, node)) + for input in inputs: + if input == first_input: + input_qspec_map[input] = quant_config.input_activation + else: + input_qspec_map[input] = shared_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 10: various normalizations +@register_annotator([torch.ops.aten.rms_norm.default]) +def annotate_rms_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + 
quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + quantization_annotation.input_qspec_map[node.args[2]] = ( + quant_config.input_activation + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +@register_annotator([torch.ops.aten.group_norm.default]) +def annotate_group_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + quantization_annotation.input_qspec_map[node.args[2]] = ( + quant_config.weight + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +@register_annotator([torch.ops.aten.layer_norm.default]) +def annotate_layer_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + quantization_annotation.input_qspec_map[node.args[2]] = ( + quant_config.input_activation + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default]) +def annotate_batch_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + + quantization_annotation.input_qspec_map[node.args[1]] = ( + quant_config.input_activation + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +# CASE 11: Sigmoid +@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default]) +def annotate_sigmoid(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + input_qspec_map = {} + input_act = node.args[0] + input_qspec_map[input_act] = quant_config.input_activation + + assert isinstance(input_act, Node) + out_qconf = quant_config.output_activation + + q_max = ( + torch.iinfo(out_qconf.dtype).max + if out_qconf.quant_max is None + else out_qconf.quant_max + ) + q_min = ( + torch.iinfo(out_qconf.dtype).min + if out_qconf.quant_min is None + else out_qconf.quant_min + ) + + scale = 1 / (q_max - q_min + 1) + + bias_obs_ctr = FixedQParamsObserver.with_args( + scale=scale, + zero_point=0, + 
dtype=quant_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, + quant_max=q_max, + quant_min=q_min, + ) + + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( + dtype=quant_config.output_activation.dtype, + quant_max=q_max, + quant_min=q_min, + observer_or_fake_quant_ctr=bias_obs_ctr, + qscheme=torch.torch.per_tensor_affine, + ) + + if _is_float_tensor(node): + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=out_act_quantization_spec, + _annotated=True, + ) diff --git a/backends/samsung/quantizer/qconfig.py b/backends/samsung/quantizer/qconfig.py new file mode 100644 index 00000000000..f32c8d39796 --- /dev/null +++ b/backends/samsung/quantizer/qconfig.py @@ -0,0 +1,174 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import IntEnum, unique +from typing import Callable, Optional + +import torch +from torchao.quantization.pt2e import ( + FakeQuantize, + MinMaxObserver, + PerChannelMinMaxObserver, +) +from torchao.quantization.pt2e.quantizer import QuantizationSpec + + +@unique +class Precision(IntEnum): + A8W8 = 3 + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec | Callable] + + +def get_quant_config( + precision: Precision, + is_per_channel: bool = False, + is_qat: bool = False, +) -> QuantizationConfig: + + precision_mappings = { + Precision.A8W8: get_a8w8_enn_quant_config, + } + if precision not in precision_mappings: + raise RuntimeError("Unrecognized precision setting.") + + is_weight_symm = is_per_channel + + qconfig_fn = precision_mappings[precision] + return qconfig_fn(is_per_channel, is_qat, wei_symmetric=is_weight_symm) + + +def _get_activation_qspec( + dtype, + is_symmetric, + is_qat, + observer_cls=MinMaxObserver, + quant_min=None, + quant_max=None, +): + eps_value = 2**-12 + if quant_max is None: + quant_max = torch.iinfo(dtype).max + if quant_min is None: + quant_min = torch.iinfo(dtype).min + + qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine + if is_qat: + observer_or_fake_quant = FakeQuantize.with_args( + observer=observer_cls, eps=eps_value + ) + else: + observer_or_fake_quant = observer_cls.with_args(eps=eps_value) + + return QuantizationSpec( + dtype=dtype, + quant_min=quant_min, + quant_max=quant_max, + qscheme=qscheme, + observer_or_fake_quant_ctr=observer_or_fake_quant, + ) + + +def _get_weight_qspec( + dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None +): + assert is_symmetric or not is_per_channel, "Not support asymm+perchannel mode" + + eps_value = 2**-12 + + if quant_max is None: + quant_max = torch.iinfo(dtype).max + if quant_min is None: + quant_min = torch.iinfo(dtype).min + + if not is_per_channel: + qscheme = ( + torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine + ) + observer_cls = MinMaxObserver + else: + qscheme = ( + torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine + ) + observer_cls = PerChannelMinMaxObserver + + if is_qat: + observer_or_fake_quant = FakeQuantize.with_args( + observer=observer_cls, eps=eps_value + ) + else: + 
observer_or_fake_quant = observer_cls.with_args(eps=eps_value) + + return QuantizationSpec( + dtype=dtype, + quant_min=quant_min, + quant_max=quant_max, + qscheme=qscheme, + ch_axis=0, + observer_or_fake_quant_ctr=observer_or_fake_quant, + ) + + +def get_a8w8_enn_quant_config( + is_per_channel=True, is_qat=False, act_symmetric=False, wei_symmetric=False +) -> QuantizationConfig: + act_quantization_spec = _get_activation_qspec(torch.int8, act_symmetric, is_qat) + wgt_quantization_spec = _get_weight_qspec( + torch.int8, wei_symmetric, is_per_channel, is_qat + ) + bias_quantization_spec = None + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=wgt_quantization_spec, + bias=bias_quantization_spec, + ) + return quantization_config + + +class QuantInfo: + def __init__(self, torch_dtype: torch.dtype, string: str): + self._torch_dtype = torch_dtype + self._string = string + + @property + def torch_dtype(self): + return self._torch_dtype + + @property + def string(self): + return self._string + + +class QuantInfoManager: + QUANT_INFO_MAP = { + Precision.A8W8: (QuantInfo(torch.int8, "INT8"), QuantInfo(torch.int8, "INT8")), + } + FP_INFO = ( + QuantInfo(torch.float32, "FLOAT32"), + QuantInfo(torch.float32, "FLOAT32"), + ) + + def __init__(self): + self.precision = None + + def set_precision(self, precision: Precision): + self.precision = precision + + @property + def weight_precison(self) -> Optional[QuantInfo]: + return self.QUANT_INFO_MAP.get(self.precision, self.FP_INFO)[0] + + @property + def act_precision(self) -> Optional[QuantInfo]: + return self.QUANT_INFO_MAP.get(self.precision, self.FP_INFO)[1] diff --git a/backends/samsung/quantizer/quantizer.py b/backends/samsung/quantizer/quantizer.py new file mode 100644 index 00000000000..cf46677d000 --- /dev/null +++ b/backends/samsung/quantizer/quantizer.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Sequence + +import torch +from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer + +from .annotator import annotate +from .qconfig import get_quant_config, Precision, QuantInfoManager + + +global_quant_info = QuantInfoManager() + + +class EnnQuantizer(Quantizer): + + def __init__(self): + super().__init__() + + self._precision = Precision.A8W8 + global_quant_info.set_precision(self._precision) + self._is_per_channel = True + self._is_qat = False + self.custom_quant_annotations: Sequence[Callable] = [] + + def setup_precision(self, quant_dtype: Precision) -> None: + assert quant_dtype in Precision, f"No support for Precision {quant_dtype}." + self._precision = quant_dtype + global_quant_info.set_precision(self._precision) + + def setup_quant_params( + self, quant_dtype: Precision, is_per_channel=True, is_qat=False + ) -> None: + assert quant_dtype in Precision, f"No support for Precision {quant_dtype}." 
+ self._precision = quant_dtype + self._is_per_channel = is_per_channel + self._is_qat = is_qat + + def annotate(self, model: GraphModule) -> GraphModule: + self._annotate(model) + self._annotate_custom_annotation(model) + return model + + def _annotate(self, gm: GraphModule) -> None: + quant_config = get_quant_config( + self._precision, self._is_per_channel, self._is_qat + ) + annotate(gm.graph, quant_config) + + def add_custom_quant_annotations( + self, custom_quant_annotations: Sequence[Callable] + ) -> None: + self.custom_quant_annotations = custom_quant_annotations + + def _annotate_custom_annotation(self, gm: GraphModule) -> None: + for annotation_func in self.custom_quant_annotations: + annotation_func(gm) + + def validate(self, model: torch.fx.GraphModule) -> None: + return diff --git a/backends/samsung/serialization/enn_graph_schema.py b/backends/samsung/serialization/enn_graph_schema.py index 7e74182f9d7..5209a8672ee 100644 --- a/backends/samsung/serialization/enn_graph_schema.py +++ b/backends/samsung/serialization/enn_graph_schema.py @@ -5,13 +5,16 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import executorch.backends.samsung.python.PyGraphWrapperAdaptor as PyGraphWrapper import numpy as np import torch +from executorch.backends.samsung.builders.utils import DATA_TYPE_STR_MAPPING +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.backends.samsung.utils.utils import quantize_tensor class EnnGraph: @@ -24,6 +27,10 @@ def __init__(self): self.inputs = [] self.outputs = [] + def init(self, name: str, soc_name): + self.name = name + self.soc_name = soc_name + def define_op( self, name, @@ -46,22 +53,54 @@ def define_op( py_param_wrapper.SetScalarValue(params[key]) else: logging.error("Unsupported param type.") + # Set op.AddOpParam(py_param_wrapper) self.graph.DefineOpNode(op) - def define_tensor( + def define_tensor( # noqa: C901 self, name: str, shape: List, data_type: str, tensor_type: str, data: Optional[Union[np.ndarray, torch.Tensor]] = None, + quant_param: Optional[Dict[str, Any]] = None, ) -> int: layout = "NCHW" if len(shape) == 4 else "UNDEFINED" + if quant_param is not None: + data_type = DATA_TYPE_STR_MAPPING[ + quant_param[QuantConstants.QUANT_KEY.quant_dtype] + ] + tensor = PyGraphWrapper.PyEnnTensorWrapper(name, shape, data_type, layout) + if quant_param is not None: + need_quantize = True + + scales = self._affine_meta_param( + quant_param[QuantConstants.QUANT_KEY.scale] + ) + zero_points = self._affine_meta_param( + quant_param[QuantConstants.QUANT_KEY.zero_point] + ) + q_dtype = self._affine_meta_param( + quant_param[QuantConstants.QUANT_KEY.quant_dtype] + ) + tensor.AddQuantizeParam(q_dtype, scales, zero_points) + + if need_quantize and data is not None: + if isinstance(data, np.ndarray): + data = torch.tensor(data) + data = quantize_tensor( + data, + scales, + zero_points, + quant_param[QuantConstants.QUANT_KEY.quant_dtype], + axis=quant_param.get("axis"), + ) + if data is not None: if isinstance(data, torch.Tensor): data = data.detach().numpy() @@ -83,3 +122,20 @@ def finish(self): def serialize(self): return self.graph.Serialize() + + @staticmethod + def _affine_meta_param(param: Any) -> str: + type_str_affine_table = { + torch.int8: "AINT8", + } + if isinstance(param, str): + return param + if isinstance(param, (float, int)): + return [param] + if hasattr(param, "tolist"): + return 
param.tolist() + if isinstance(param, torch.dtype): + # Convenient for debugging + param = type_str_affine_table.get(param, "") + + return param diff --git a/backends/samsung/utils/constants.py b/backends/samsung/utils/constants.py new file mode 100644 index 00000000000..7c3997b9fe2 --- /dev/null +++ b/backends/samsung/utils/constants.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantConstants: + # TODO: check keys + class QUANT_KEY: + scale = "scales" + zero_point = "zero_points" + quant_min = "quant_min" + quant_max = "quant_max" + quant_dtype = "quant_dtype" + + PERCHANNEL_KEY_MAP = { + "scales": QUANT_KEY.scale, + "zero_points": QUANT_KEY.zero_point, + "quant_min": QUANT_KEY.quant_min, + "quant_max": QUANT_KEY.quant_max, + "dtype": QUANT_KEY.quant_dtype, + } + # SNC ir always use key 'scales' and 'zero_points' + PERTENSOR_KEY_MAP = { + "scale": QUANT_KEY.scale, + "zero_point": QUANT_KEY.zero_point, + "quant_min": QUANT_KEY.quant_min, + "quant_max": QUANT_KEY.quant_max, + "dtype": QUANT_KEY.quant_dtype, + } + + QUANT_OPS_KEY_MAP = { + exir_ops.edge.quantized_decomposed.quantize_per_channel.default: PERCHANNEL_KEY_MAP, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: PERTENSOR_KEY_MAP, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: PERTENSOR_KEY_MAP, + } + + DEQUANT_OPS_KEY_MAP = { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: PERTENSOR_KEY_MAP, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: PERTENSOR_KEY_MAP, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: PERCHANNEL_KEY_MAP, + } diff --git a/backends/samsung/utils/export_utils.py b/backends/samsung/utils/export_utils.py index aaf407ef0b3..39992f2ea2a 100644 --- a/backends/samsung/utils/export_utils.py +++ b/backends/samsung/utils/export_utils.py @@ -4,20 +4,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
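For orientation, a minimal end-to-end sketch of how the two helpers added to this file, quantize_module and to_edge_transform_and_lower_to_enn, are meant to be combined; the toy model, input shape, chipset string, and output path are illustrative only and mirror the example scripts added later in this patch:

import torch
from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner
from executorch.backends.samsung.quantizer import Precision
from executorch.backends.samsung.serialization.compile_options import (
    gen_samsung_backend_compile_spec,
)
from executorch.backends.samsung.utils.export_utils import (
    quantize_module,
    to_edge_transform_and_lower_to_enn,
)
from executorch.exir import ExecutorchBackendConfig
from executorch.extension.export_util.utils import save_pte_program

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)
calibration_data = [example_inputs for _ in range(8)]

# Post-training A8W8 quantization with the ENN quantizer.
quantized = quantize_module(model, example_inputs, calibration_data, Precision.A8W8)

# Lower the quantized module to the ENN backend for a target chipset.
compile_specs = [gen_samsung_backend_compile_spec("E9955")]
edge_prog = to_edge_transform_and_lower_to_enn(
    quantized, example_inputs, compile_specs=compile_specs
)
edge = edge_prog.to_backend(EnnPartitioner(compile_specs))
exec_prog = edge.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=True)
)
save_pte_program(exec_prog, "my_model", "./out")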
-from typing import Optional, Tuple +import logging +from typing import List, Optional, Tuple import executorch.exir as exir import torch +from executorch.backends.samsung._passes.fuse_conv_act import FuseConvActPass +from executorch.backends.samsung._passes.remove_useless_ops import RemoveUselessOpPass from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer.quantizer import EnnQuantizer, Precision +from executorch.backends.transforms.decompose_sdpa import ( + DecomposeScaledDotProductAttention, +) from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig from executorch.exir.backend.backend_details import CompileSpec - from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_manager import PassType from executorch.exir.program._program import to_edge_transform_and_lower +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def get_edge_compile_config(): + # Maybe most ops in non-decomposition list should be added here + # TODO: to confirm whether all op in none-decomposed table should be added here return EdgeCompileConfig( _skip_dim_order=True, _core_aten_ops_exception_list=[ @@ -29,24 +39,55 @@ def get_edge_compile_config(): exir_ops.edge.aten._safe_softmax.default, exir_ops.edge.aten.layer_norm.default, exir_ops.edge.aten.matmul.default, + exir_ops.edge.aten.hardsigmoid.default, ], ) +def get_enn_pass_list() -> List[PassType]: + return [ + RemoveUselessOpPass(), + RemoveCloneOpsTransform(), + FuseConvActPass(), + ] + + +def quantize_module( + module: torch.nn.Module, + inputs, + calibration_dataset, + precision: Precision, + is_per_channel: bool = True, + is_qat: bool = False, +) -> torch.nn.Module: + quantizer = EnnQuantizer() + quantizer.setup_quant_params(precision, is_per_channel, is_qat) + logging.info("Export nn module for quantization...") + exported_module = torch.export.export_for_training(module, inputs).module() + DecomposeScaledDotProductAttention()(exported_module) + logging.info("Quantizing the module...") + annotated_module = prepare_pt2e(exported_module, quantizer) + for data in calibration_dataset: + annotated_module(*data) + quantized_module = convert_pt2e(annotated_module, fold_quantize=False) + logging.info("Quantizing finished.") + return quantized_module + + def to_edge_transform_and_lower_to_enn( module: torch.nn.Module, inputs: Tuple[torch.Tensor], + custom_pass_config: List[PassType] = None, compile_specs: Optional[CompileSpec] = None, ) -> exir.ExecutorchProgramManager: - assert ( - compile_specs is not None - ), "Please provide compile specifications for enn backend" + assert compile_specs is not None, "For now, we must deliver complile specs" prog = torch.export.export(module, inputs) - - ahead_pass_list = [RemoveCloneOpsTransform()] + pass_list = get_enn_pass_list() + if custom_pass_config: + pass_list.extend(custom_pass_config) return to_edge_transform_and_lower( prog, - ahead_pass_list, + pass_list, {"forward": [EnnPartitioner(compile_specs)]}, compile_config=get_edge_compile_config(), ) diff --git a/backends/samsung/utils/utils.py b/backends/samsung/utils/utils.py index 5da9808f38f..bbbec518b2a 100644 --- a/backends/samsung/utils/utils.py +++ b/backends/samsung/utils/utils.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import List +from typing import List, Optional, Tuple import torch from executorch.backends.transforms.utils import is_param_node from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram @@ -35,3 +36,90 @@ def is_graph_output(node: torch.fx.Node) -> bool: ): return True return False + + +def _quantize_per_tensor( + in_tensor: torch.Tensor, + scales: List[float], + zeropoints: List[int], + dtype: torch.dtype, + qrange: Optional[Tuple[int, int]], +): + assert ( + len(scales) == 1 + ), "For per-tensor quantization, there should be only one scale/zeropoint" + return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + in_tensor, + torch.Tensor(scales), + torch.Tensor(zeropoints), + qrange[0], + qrange[1], + dtype, + ) + + +def _quantize_per_channel( + in_tensor: torch.Tensor, + scales: List[float], + zeropoints: List[int], + dtype: torch.dtype, + qrange: Optional[Tuple[int, int]], + axis: Optional[int], # Only for per-channel +): + assert ( + len(scales) == in_tensor.shape[axis] + ), "Shape not match for quant params and input tensor" + return exir_ops.edge.quantized_decomposed.quantize_per_channel.default( + in_tensor, + torch.Tensor(scales), + torch.Tensor(zeropoints), + axis, + qrange[0], + qrange[1], + dtype, + ) + + +def quantize_tensor( + in_tensor: torch.Tensor, + scales: List[float], + zeropoints: List[int], + dtype: torch.dtype, + qrange: Optional[Tuple[int, int]] = None, + axis: Optional[int] = None, # Only for per-channel +) -> torch.Tensor: + """ + To quantize constant tensor by executorch OPs. If `axis` not set, we quantize the tensor by per tensor. + If `axis` was set, we do per-channel quantize. + + :param in_tensor: The tensor to be quantized + :param scales: List of scales. For per-tensor quantization, it should contain only one element + :param zeropoints: List of zeropoints. For per-tensor quantization, it should contain only one element + :param dtype: The output dtype + :param qrange: The quantization range (qmin, qmax). + If not set, we will get the maximum range of the dtype by `torch.iinfo` + :param axis: We do per-channel quantize by which axis. + Only when this parameter set, we do per-channel quantization + :type in_tensor: torch.Tensor + :type scalse: List[float] + :type zeropoints: List[int] + :type dtype: torch.dtype + :type qrange: Optional[Tuple[int,int]] + :type axis: Optional[int] + :return: The quantized tensor + """ + assert len(scales) == len( + zeropoints + ), "scales should have same shape with zeropoints" + if not qrange: + qrange = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) + + if axis is not None: + return _quantize_per_channel(in_tensor, scales, zeropoints, dtype, qrange, axis) + return _quantize_per_tensor( + in_tensor, + scales, + zeropoints, + dtype, + qrange, + ) diff --git a/examples/samsung/scripts/deeplab_v3.py b/examples/samsung/scripts/deeplab_v3.py new file mode 100644 index 00000000000..b1e8fef65fe --- /dev/null +++ b/examples/samsung/scripts/deeplab_v3.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
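As a usage hint for the script below: assuming the ExecuTorch examples are importable as a package, a typical quantized-export run is `python -m executorch.examples.samsung.scripts.deeplab_v3 -c E9955 -p A8W8 -d <path-to-VOC-root>`; when -d is omitted, the model's built-in example inputs are used for calibration instead of the VOC validation set.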
+ +import argparse +import os +from typing import Optional + +import torch +import torchvision.transforms.v2 as vision_transform_v2 + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet50Model +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program +from torchvision.datasets import VOCSegmentation + + +def get_dataset( + data_dir: str, + calinum=100, + input_transform_compose: Optional[vision_transform_v2.Compose] = None, + target_transform_compose: Optional[vision_transform_v2.Compose] = None, +): + if not input_transform_compose: + input_transform_compose = vision_transform_v2.Compose( + [ + vision_transform_v2.Resize([224, 224]), + vision_transform_v2.ToImage(), + vision_transform_v2.ToDtype(torch.float32, scale=True), + vision_transform_v2.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + vision_transform_v2.Lambda(lambda x: x.unsqueeze(0)), # Add batch dim + ] + ) + if not target_transform_compose: + target_transform_compose = vision_transform_v2.Compose( + [ + vision_transform_v2.Resize([224, 224]), + vision_transform_v2.ToImage(), + vision_transform_v2.ToDtype(torch.long, scale=False), + vision_transform_v2.Lambda(lambda x: x.unsqueeze(0)), # Add batch dim + ] + ) + voc_dataset = VOCSegmentation( + data_dir, + "2012", + "val", + transform=input_transform_compose, + target_transform=target_transform_compose, + ) + example_input = [ + (voc_dataset[i][0],) for i in range(min(calinum, len(voc_dataset))) + ] + return example_input + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=("path to the validation folder of VOC dataset. "), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./deeplab_v3", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. 
+ os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "deeplab_v3" + instance = DeepLabV3ResNet50Model() + model = DeepLabV3ResNet50Model().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs = get_dataset( + data_dir=f"{args.dataset}", + calinum=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/edsr.py b/examples/samsung/scripts/edsr.py new file mode 100644 index 00000000000..f300a9c8547 --- /dev/null +++ b/examples/samsung/scripts/edsr.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +from typing import List, Optional, Tuple + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.edsr import EdsrModel +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + +from torchsr import transforms + + +def get_dataset( + root_dir: str, + calinum=100, + transform_compose: Optional[transforms.Compose] = None, +) -> Tuple: + """ + Generate test data from B100 dataset for quantization model + + :param root_dir: Dir of dataset. The real dataset should be in root_dir/SRBenchmarks/benchmark/ + :param dataset_name: data_set name + :param testnum: Number of test data. Default 500 + :param transform_compose: Transforms to be applied to data. 
+ Default: + transform_compose = transforms.Compose( + [transforms.ToTensor()] # Convert Pillows Image to tensor + ) + :type root_dir: str + :type calinum: int + :type testnum: int + :type transform_compose: transforms.Compose | None + :return: (example_input, cali_data, test_data) + """ + + class SrResize: + def __init__(self, expected_size: List[List[int]]): + self.expected_size = expected_size + + def __call__(self, x): + return ( + x[0].resize(self.expected_size[0]), + x[1].resize(self.expected_size[1]), + ) + + class SrUnsqueeze: + def __call__(self, x): + return ( + x[0].unsqueeze(0), + x[1].unsqueeze(0), + ) + + if not transform_compose: + transform_compose = transforms.Compose( + [ + SrResize([[448, 448], [224, 224]]), + transforms.ToTensor(), # Convert Pillows Image to tensor + SrUnsqueeze(), + ] + ) + from torchsr.datasets import B100 + + dataset = B100(root=root_dir, transform=transform_compose, scale=2) + example_data = [(dataset[i][1],) for i in range(min(calinum, len(dataset)))] + return example_data + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=("path to the validation folder of B100"), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./edsr", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. 
+ os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "edsr" + instance = EdsrModel() + model = EdsrModel().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs = get_dataset( + root_dir=f"{args.dataset}", + calinum=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/inception_v3.py b/examples/samsung/scripts/inception_v3.py new file mode 100644 index 00000000000..77540285eab --- /dev/null +++ b/examples/samsung/scripts/inception_v3.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.inception_v3 import InceptionV3Model +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + image_shape = (256, 256) + crop_size = 224 + shuffle = True + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(image_shape), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=shuffle, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=( + "path to the validation folder of ImageNet dataset. 
" + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./inception_v3", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "inception_v3" + instance = InceptionV3Model() + model = InceptionV3Model().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + target = None + input_list = None + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/inception_v4.py b/examples/samsung/scripts/inception_v4.py new file mode 100644 index 00000000000..3140682998c --- /dev/null +++ b/examples/samsung/scripts/inception_v4.py @@ -0,0 +1,167 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import os + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.inception_v4 import InceptionV4Model +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + image_shape = (299, 299) + shuffle = True + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(image_shape), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=shuffle, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./inception_v4", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. 
+ os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "inception_v4" + instance = InceptionV4Model() + model = InceptionV4Model().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + target = None + input_list = None + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/mobilenet_v2.py b/examples/samsung/scripts/mobilenet_v2.py new file mode 100644 index 00000000000..7c69de38e2c --- /dev/null +++ b/examples/samsung/scripts/mobilenet_v2.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.mobilenet_v2 import MV2Model +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + image_shape = (256, 256) + crop_size = 224 + shuffle = True + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(image_shape), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=shuffle, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. 
E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./mobilenetV2", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "mobilenetV2_enn" + instance = MV2Model(False) + model = MV2Model().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + target = None + input_list = None + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/mobilenet_v3.py b/examples/samsung/scripts/mobilenet_v3.py new file mode 100644 index 00000000000..3cc8eadf633 --- /dev/null +++ b/examples/samsung/scripts/mobilenet_v3.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
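+
+# Calibration data for this script comes from an ImageNet-style folder passed
+# via --dataset: images are loaded with torchvision's ImageFolder, resized to
+# 256x256, center-cropped to 224x224 and normalized before being fed to the
+# quantizer. Without --dataset, MV3Model's example inputs are repeated
+# calibration_number times instead.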
+ +import argparse +import os + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.mobilenet_v3 import MV3Model +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + image_shape = (256, 256) + crop_size = 224 + shuffle = True + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(image_shape), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=shuffle, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./mobilenet_v3", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. 
+ os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "mobilenet_v3" + instance = MV3Model() + model = MV3Model().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + target = None + input_list = None + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/resnet18.py b/examples/samsung/scripts/resnet18.py new file mode 100644 index 00000000000..2f3233214ce --- /dev/null +++ b/examples/samsung/scripts/resnet18.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.resnet import ResNet18Model +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + image_shape = (256, 256) + crop_size = 224 + shuffle = True + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(image_shape), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=shuffle, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. 
E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./resnet18", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "resnet18" + instance = ResNet18Model() + model = ResNet18Model().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + target = None + input_list = None + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/resnet50.py b/examples/samsung/scripts/resnet50.py new file mode 100644 index 00000000000..1d6c348b641 --- /dev/null +++ b/examples/samsung/scripts/resnet50.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
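+
+# Outputs: the lowered program is written to <artifact>/resnet50.pte. With
+# --dump, the script also saves the test input ("float_in"), the float model
+# output ("float_out") and, when quantized, the quantized model output
+# ("quant_out") via save_tensors for reference and comparison.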
+ +import argparse +import os + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.resnet import ResNet50Model +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + image_shape = (256, 256) + crop_size = 224 + shuffle = True + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(image_shape), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=shuffle, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./resnet50", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. 
+ os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "resnet50" + instance = ResNet50Model() + model = ResNet50Model().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + target = None + input_list = None + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/vit.py b/examples/samsung/scripts/vit.py new file mode 100644 index 00000000000..19c22c473cd --- /dev/null +++ b/examples/samsung/scripts/vit.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.torchvision_vit import TorchVisionViTModel +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + image_shape = (256, 256) + crop_size = 224 + shuffle = True + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(image_shape), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=shuffle, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. 
E9945, E9955, etc", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + default=None, + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + ) + + parser.add_argument( + "-p", + "--precision", + default=None, + help=("Quantizaiton precision. If not set, the model will not be quantized."), + choices=[None, "A8W8"], + type=str, + ) + + parser.add_argument( + "-cn", + "--calibration_number", + default=100, + help=( + "Assign the number of data you want " + "to use for calibrating the quant params." + ), + type=int, + ) + + parser.add_argument( + "--dump", + default=False, + const=True, + nargs="?", + help=("Whether to dump all outputs. If not set, we only dump pte."), + type=bool, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. ", + default="./vision_transformer", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + # build pte + pte_filename = "vision_transformer" + instance = TorchVisionViTModel() + model = TorchVisionViTModel().get_eager_model().eval() + assert args.calibration_number + if args.dataset: + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=args.calibration_number, + ) + else: + inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)] + target = None + input_list = None + + test_in = inputs[0] + float_out = model(*test_in) + + compile_specs = [gen_samsung_backend_compile_spec(args.chipset)] + + if args.precision: + model = quantize_module( + model, inputs[0], inputs, getattr(Precision, args.precision) + ) + quant_out = model(*test_in) + + edge_prog = to_edge_transform_and_lower_to_enn( + model, inputs[0], compile_specs=compile_specs + ) + + edge = edge_prog.to_backend(EnnPartitioner(compile_specs)) + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}")) + + if args.dump: + save_tensors(test_in, "float_in", args.artifact) + save_tensors(float_out, "float_out", args.artifact) + if args.precision: + save_tensors(quant_out, "quant_out", args.artifact) diff --git a/examples/samsung/scripts/wav2letter.py b/examples/samsung/scripts/wav2letter.py new file mode 100644 index 00000000000..33069105d99 --- /dev/null +++ b/examples/samsung/scripts/wav2letter.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
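+
+# Unlike the image models above, this script calibrates on LibriSpeech
+# (test-clean, loaded via torchaudio). Waveforms are zero-padded to a fixed
+# token length (300 tokens by default, 320 samples per token) so every
+# calibration input has the same shape. The model uses a vocab size of 29
+# (space, a-z, apostrophe, blank), and a retrained weight file with that vocab
+# size can be supplied through --weight.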
+ +import argparse +import os +from typing import List + +import torch + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer import Precision +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.backends.samsung.utils.export_utils import ( + quantize_module, + to_edge_transform_and_lower_to_enn, +) +from executorch.examples.models.wav2letter import Wav2LetterModel +from executorch.examples.samsung.utils import save_tensors +from executorch.exir import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +class DataManager: + class Encoder: + def __init__(self, vocab, blank_label="*"): + self.vocab = vocab + self.char_to_id = {c: i for i, c in enumerate(vocab)} + self.blank_label = blank_label + + def encode(self, text): + return [self.char_to_id[c] for c in text.lower()] + + @classmethod + def _get_voice_dataset( + cls, data_size: int, data_dir: str, labels: List[str], fixed_token_num: int + ): + from torch.utils.data import DataLoader + from torchaudio.datasets import LIBRISPEECH + + def collate_fun(batch, encode_fn, mode="train"): + waves = [] + text_ids = [] + input_lengths = [] + output_lengths = [] + + if mode == "train": + shifts = torch.randn(len(batch)) > 0.0 + + for i, (wave, _, text, *_) in enumerate(batch): + if mode == "train" and shifts[i]: + wave = wave[:, 160:] + waves.append(wave[0]) + ids = torch.LongTensor(encode_fn(text)) + text_ids.append(ids) + input_lengths.append(wave.size(1) // 320) + output_lengths.append(len(ids)) + + waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True).unsqueeze( + 1 + ) + labels = torch.nn.utils.rnn.pad_sequence(text_ids, batch_first=True) + + return waves, labels, input_lengths, output_lengths + + encoder = cls.Encoder(labels) + + testset_url = "test-clean" + dataset = LIBRISPEECH(data_dir, url=testset_url) + data_loader = DataLoader( + dataset=dataset, + batch_size=1, + shuffle=True, + collate_fn=lambda x: collate_fun(x, encoder.encode, "valid"), + ) + # prepare input data + inputs, targets = [], [] + in_lens, tar_lens = [], [] + + def _loader(): + for waves, labels, inputs_len, targets_len in data_loader: + if inputs_len[0] >= fixed_token_num: + continue + zero_padding = torch.zeros( + [1, 1, fixed_token_num * 320 - waves.shape[2]] + ) + waves = torch.concat((waves, zero_padding), axis=2) + yield waves, labels, [fixed_token_num + 1], targets_len + + for i, (waves, labels, inputs_len, targets_len) in enumerate( + _loader() + ): # waves, labels, input_lens, output_lens + inputs.append(waves) + targets.append(labels) + in_lens.append(inputs_len) + tar_lens.append(targets_len) + if i >= data_size: + break + + return inputs, targets, in_lens, tar_lens + + @classmethod + def get_dataset( + cls, + data_dir: str, + calinum=100, + fixed_out_token=300, + labels=None, + ): + if labels is None: + labels = [" ", *"abcdefghijklmnopqrstuvwxyz", "'", "*"] + dataset = cls._get_voice_dataset(calinum, data_dir, labels, fixed_out_token) + example_input = [(dataset[0][i],) for i in range(min(calinum, len(dataset[0])))] + return example_input + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--chipset", + default="E9955", + help="Samsung chipset, i.e. 
E9945, E9955, etc",
+        type=str,
+    )
+    parser.add_argument(
+        "-d",
+        "--dataset",
+        default=None,
+        help=(
+            "path to the LibriSpeech dataset root used for calibration "
+            "(the test-clean split is loaded via torchaudio), "
+            "e.g. --dataset /path/to/librispeech_root"
+        ),
+        type=str,
+    )
+
+    parser.add_argument(
+        "-p",
+        "--precision",
+        default=None,
+        help=("Quantization precision. If not set, the model will not be quantized."),
+        choices=[None, "A8W8"],
+        type=str,
+    )
+
+    parser.add_argument(
+        "-cn",
+        "--calibration_number",
+        default=100,
+        help=(
+            "Number of samples to use "
+            "for calibrating the quant params."
+        ),
+        type=int,
+    )
+
+    parser.add_argument(
+        "--dump",
+        default=False,
+        const=True,
+        nargs="?",
+        help=("Whether to dump all outputs. If not set, we only dump the pte."),
+        type=bool,
+    )
+
+    parser.add_argument(
+        "-w",
+        "--weight",
+        default=None,
+        help="Absolute path to retrained w2l weights (in .pt format); the vocab size should be 29",
+        type=str,
+    )
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing artifacts generated by this example.",
+        default="./wav2letter",
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    # ensure the working directory exists.
+    os.makedirs(args.artifact, exist_ok=True)
+
+    # build pte
+    pte_filename = "wav2letter"
+    instance = Wav2LetterModel()
+    instance.vocab_size = 29
+    model = instance.get_eager_model().eval()
+    if args.weight:
+        weight = torch.load(args.weight, weights_only=True)
+        model.load_state_dict(weight)
+    assert args.calibration_number
+    if args.dataset:
+        inputs = DataManager.get_dataset(
+            data_dir=f"{args.dataset}",
+            calinum=args.calibration_number,
+        )
+    else:
+        inputs = [instance.get_example_inputs() for _ in range(args.calibration_number)]
+
+    test_in = inputs[0]
+    float_out = model(*test_in)
+
+    compile_specs = [gen_samsung_backend_compile_spec(args.chipset)]
+
+    if args.precision:
+        model = quantize_module(
+            model, inputs[0], inputs, getattr(Precision, args.precision)
+        )
+        quant_out = model(*test_in)
+
+    edge_prog = to_edge_transform_and_lower_to_enn(
+        model, inputs[0], compile_specs=compile_specs
+    )
+
+    edge = edge_prog.to_backend(EnnPartitioner(compile_specs))
+    exec_prog = edge.to_executorch(
+        config=ExecutorchBackendConfig(extract_delegate_segments=True)
+    )
+    save_pte_program(exec_prog, pte_filename, os.path.join(f"{args.artifact}"))
+
+    if args.dump:
+        save_tensors(test_in, "float_in", args.artifact)
+        save_tensors(float_out, "float_out", args.artifact)
+        if args.precision:
+            save_tensors(quant_out, "quant_out", args.artifact)
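
The example scripts added above share one export flow; only the model wrapper, the calibration data loader and the default artifact path differ. A condensed sketch of that shared flow follows (every import and call is taken from the scripts themselves; export_to_enn is only an illustrative wrapper, not a helper that exists in the tree):

    # Minimal sketch of the shared export flow used by these example scripts.
    import os

    from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner
    from executorch.backends.samsung.quantizer import Precision
    from executorch.backends.samsung.serialization.compile_options import (
        gen_samsung_backend_compile_spec,
    )
    from executorch.backends.samsung.utils.export_utils import (
        quantize_module,
        to_edge_transform_and_lower_to_enn,
    )
    from executorch.exir import ExecutorchBackendConfig
    from executorch.extension.export_util.utils import save_pte_program

    def export_to_enn(model, inputs, chipset="E9955", precision=None,
                      artifact="./out", pte_name="model"):
        compile_specs = [gen_samsung_backend_compile_spec(chipset)]
        if precision:  # e.g. "A8W8" post-training quantization
            model = quantize_module(
                model, inputs[0], inputs, getattr(Precision, precision)
            )
        edge_prog = to_edge_transform_and_lower_to_enn(
            model, inputs[0], compile_specs=compile_specs
        )
        exec_prog = edge_prog.to_backend(EnnPartitioner(compile_specs)).to_executorch(
            config=ExecutorchBackendConfig(extract_delegate_segments=True)
        )
        os.makedirs(artifact, exist_ok=True)
        save_pte_program(exec_prog, pte_name, artifact)

Each script is runnable as a module, for example:
python -m executorch.examples.samsung.scripts.resnet50 -c E9955 -p A8W8 -d imagenet-mini/val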