Merged
12 changes: 10 additions & 2 deletions backends/arm/_passes/replace_scalar_with_tensor_pass.py
@@ -7,6 +7,7 @@
from typing import Dict, Set, Type, Union

import torch
from executorch.backends.arm._passes.insert_table_ops import TableOps

from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -64,7 +65,6 @@
Union[EdgeOpOverload, torch._ops.OpOverload],
] = _common_ops | {
exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
}

_int_profile_ops: Dict[
@@ -101,7 +101,15 @@ def call_operator(self, op, args, kwargs, meta):
included_ops |= _fp_profile_ops

if included_ops == {}:
raise ValueError("Profile must support either INT or FP")
raise ValueError("Profile must support at least INT or FP")

if op in TableOps.included_ops():
# Do not handle quantized table ops; forward unchanged.
input_qparams = meta.data.get("input_qparams", {})
output_qparams = meta.data.get("output_qparams", {})
if len(input_qparams) > 0 and len(output_qparams) > 0:
# Forward unchanged.
return ExportPass.call_operator(self, op, args, kwargs, meta)

if op in included_ops:
# Include this op based on the current profile.
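The guard above leaves quantized table ops to the dedicated table-op lowering. A minimal standalone sketch of its intent (the helper name and the plain-dict meta are assumptions for illustration, not the pass's actual API):

def _should_forward_unchanged(op, meta_data, table_ops) -> bool:
    # A table op that already carries both input and output quantization
    # parameters is left for the table-op lowering to handle.
    if op not in table_ops:
        return False
    input_qparams = meta_data.get("input_qparams", {})
    output_qparams = meta_data.get("output_qparams", {})
    return len(input_qparams) > 0 and len(output_qparams) > 0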
36 changes: 28 additions & 8 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -146,6 +146,10 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
return checker


def _is_integer_dtype(dtype: torch.dtype) -> bool:
return not dtype.is_floating_point and not dtype.is_complex


def _is_quantized_constant(node: torch.fx.Node) -> bool:
if node.target not in (
exir_ops.edge.aten.full_like.default,
@@ -161,7 +165,7 @@ def _is_quantized_constant(node: torch.fx.Node) -> bool:
for user in users:
if user.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
dim_order_dtype = get_first_fake_tensor(user).dtype
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
if not _is_integer_dtype(dim_order_dtype):
return False
else:
return False
@@ -184,10 +188,24 @@ def is_quantized(node: torch.fx.Node) -> bool:
bool: True if the node is quantized, False otherwise.
"""

node_dtype = get_first_fake_tensor(node).dtype
# Integer-like dtype implies the node is already quantized.
if not node_dtype.is_complex and not node_dtype.is_floating_point:
return True
try:
node_dtype = get_first_fake_tensor(node).dtype
# Integer-like dtype implies the node is already quantized as long
# as inputs are not floating-point.
if _is_integer_dtype(node_dtype):
input_nodes = node.all_input_nodes
input_nodes_dtypes = [
get_first_fake_tensor(input_node).dtype for input_node in input_nodes
]
if all(
_is_integer_dtype(input_node_dtype)
for input_node_dtype in input_nodes_dtypes
):
return True

except TypeError:
# Could not determine dtype, fall back to other checks.
pass

# Nodes introduced during lowering that exclusively feed quantized users.
if _is_quantized_constant(node):
@@ -510,7 +528,7 @@ def is_node_supported(

input_quantized = input_quantized or all(
(input_node.target in DQ_OPS)
or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
or _is_integer_dtype(get_first_fake_tensor(input_node).dtype)
for input_node in node.all_input_nodes
)

@@ -519,8 +537,10 @@ def is_node_supported(
return False

all_q_users = all((output_node.target in Q_OPS) for output_node in node.users)
is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
output_quantized = output_quantized or all_q_users or not is_floating_point
output_dtype = get_first_fake_tensor(node).dtype
output_quantized = (
output_quantized or all_q_users or _is_integer_dtype(output_dtype)
)

if not output_quantized:
self.reporter.report_reject(node, "One or more outputs were not quantized.")
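A small self-contained sketch of the dtype classification the new _is_integer_dtype helper performs; the example values are illustrative and not taken from the PR's tests:

import torch

def _is_integer_dtype(dtype: torch.dtype) -> bool:
    # Integer-like means neither floating point nor complex (ints and bool).
    return not dtype.is_floating_point and not dtype.is_complex

assert _is_integer_dtype(torch.int8)
assert _is_integer_dtype(torch.bool)
assert not _is_integer_dtype(torch.float16)
assert not _is_integer_dtype(torch.complex64)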
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -356,6 +356,7 @@ def _match_pattern(
torch.ops.aten.hardswish.default,
torch.ops.aten.hardswish_.default,
torch.ops.aten.full_like.default,
torch.ops.aten.zeros_like.default,
torch.ops.aten.pow.Tensor_Scalar,
torch.ops.aten.gelu.default,
torch.ops.aten.sinh.default,
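With zeros_like added to the same annotation list as full_like, a graph such as the hypothetical module below would now have its zeros_like output annotated by the Arm quantizer (illustrative only; not a test from this PR):

import torch

class ZerosLikeAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # zeros_like is now picked up by the quantization annotator,
        # just like full_like in the list above.
        return x + torch.zeros_like(x)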
17 changes: 13 additions & 4 deletions backends/arm/test/common.py
@@ -170,13 +170,22 @@ def get_vgf_compile_spec(

if not custom_path:
custom_path = maybe_get_tosa_collate_path()
profiles = []
if "FP" in repr(tosa_spec):
artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_fp_")
elif "INT" in repr(tosa_spec):
artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_int_")
else:
profiles.append("fp")
if "INT" in repr(tosa_spec):
profiles.append("int")
if len(profiles) == 0:
raise ValueError(f"Unsupported vgf compile_spec: {repr(tosa_spec)}")

if custom_path is None:
artifact_path = "arm_vgf_"
for profile in profiles:
artifact_path = artifact_path + f"_{profile}"
artifact_path = tempfile.mkdtemp(artifact_path)
else:
artifact_path = custom_path

if not os.path.exists(artifact_path):
os.makedirs(artifact_path, exist_ok=True)

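A standalone sketch of the profile-derived artifact naming introduced above (the helper name is assumed; the real code additionally creates the directory via tempfile.mkdtemp):

def _vgf_artifact_name(tosa_spec_repr: str) -> str:
    # One directory name that reflects every profile present in the TOSA spec.
    profiles = []
    if "FP" in tosa_spec_repr:
        profiles.append("fp")
    if "INT" in tosa_spec_repr:
        profiles.append("int")
    if not profiles:
        raise ValueError(f"Unsupported vgf compile_spec: {tosa_spec_repr}")
    name = "arm_vgf_"
    for profile in profiles:
        name += f"_{profile}"
    return name

assert _vgf_artifact_name("TOSA-1.0+FP") == "arm_vgf__fp"
assert _vgf_artifact_name("TOSA-1.0+INT+FP") == "arm_vgf__fp_int"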
16 changes: 12 additions & 4 deletions backends/arm/test/ops/test_eq.py
@@ -122,7 +122,7 @@ def test_eq_scalar_tosa_INT(test_module):


@common.parametrize("test_module", test_data_tensor)
def test_eq_tensor_tosa_INT_a16w8(test_module):
def test_eq_tensor_tosa_INT_16a8w(test_module):
pipeline = TosaPipelineINT[input_t](
test_module(),
test_module().get_inputs(),
@@ -134,7 +134,7 @@ def test_eq_tensor_tosa_INT_a16w8(test_module):


@common.parametrize("test_module", test_data_scalar)
def test_eq_scalar_tosa_INT_a16w8(test_module):
def test_eq_scalar_tosa_INT_16a8w(test_module):
pipeline = TosaPipelineINT[input_t](
test_module(),
test_module().get_inputs(),
@@ -238,7 +238,11 @@ def test_eq_scalar_16a8w_u85_INT16(test_module):
@common.SkipIfNoModelConverter
def test_eq_scalar_vgf_FP_tensor(test_module):
pipeline = VgfPipeline[input_t](
test_module(), test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op
test_module(),
test_module().get_inputs(),
Equal.aten_op_Tensor,
Equal.exir_op,
tosa_version="TOSA-1.0+FP",
)
pipeline.run()

@@ -247,7 +251,11 @@ def test_eq_scalar_vgf_FP_tensor(test_module):
@common.SkipIfNoModelConverter
def test_eq_scalar_vgf_FP(test_module):
pipeline = VgfPipeline[input_t](
test_module(), test_module().get_inputs(), Equal.aten_op_Scalar, Equal.exir_op
test_module(),
test_module().get_inputs(),
Equal.aten_op_Scalar,
Equal.exir_op,
tosa_version="TOSA-1.0+FP",
)
pipeline.run()

2 changes: 1 addition & 1 deletion backends/arm/test/tester/test_pipeline.py
@@ -990,7 +990,7 @@ def __init__(
exir_op: Optional[str | List[str]] = None,
run_on_vulkan_runtime: bool = True,
vgf_compiler_flags: Optional[str] = "",
tosa_version: str = "TOSA-1.0+FP",
tosa_version: str = "TOSA-1.0+INT+FP",
symmetric_io_quantization: bool = False,
per_channel_quantization: bool = True,
use_to_edge_transform_and_lower: bool = True,
15 changes: 4 additions & 11 deletions backends/arm/vgf/compile_spec.py
@@ -28,13 +28,12 @@ def __init__(
tosa_spec (TosaSpecification | str | None): TOSA specification to
target. Strings are parsed via
:meth:`TosaSpecification.create_from_string`. Defaults to
``"TOSA-1.0+FP"``.
``"TOSA-1.0+FP+INT"``.
compiler_flags (list[str] | None): Optional converter-backend flags.

"""
if tosa_spec is None:
tosa_spec = "TOSA-1.0+FP"
if isinstance(tosa_spec, str):
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP+INT")
elif isinstance(tosa_spec, str):
tosa_spec = TosaSpecification.create_from_string(tosa_spec)

if compiler_flags is None:
@@ -55,13 +54,7 @@ def validate(self):

if "FP" not in tosa_profiles and "INT" not in tosa_profiles:
raise ValueError(
"Arm backend only supports converter-backend for FP or INT. "
f"Invalid TOSA profile: {tosa_profiles}"
)

if len(tosa_profiles) != 1:
raise ValueError(
"For now Arm backend only supports converter-backend for either FP or INT. "
"Arm backend only supports converter-backend for FP and/or INT. "
f"Invalid TOSA profile: {tosa_profiles}"
)

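A standalone sketch of the relaxed validation: a spec carrying FP, INT, or both profiles now passes, and only an empty profile set is rejected (the function name is illustrative, not part of the file):

def _validate_profiles(tosa_profiles: list) -> None:
    # Accept any non-empty combination of FP and INT.
    if "FP" not in tosa_profiles and "INT" not in tosa_profiles:
        raise ValueError(
            "Arm backend only supports converter-backend for FP and/or INT. "
            f"Invalid TOSA profile: {tosa_profiles}"
        )

_validate_profiles(["FP"])           # accepted
_validate_profiles(["FP", "INT"])    # accepted (previously rejected)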