8 changes: 4 additions & 4 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -310,11 +310,11 @@ def is_node_supported(
if not input_quantized:
return False

output_quantized = output_quantized or all(
(output_node.target == self.q_op)
or (not get_first_fake_tensor(output_node).dtype.is_floating_point)
for output_node in node.users
all_q_users = all(
(output_node.target == self.q_op) for output_node in node.users
)
is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
output_quantized = output_quantized or all_q_users or not is_floating_point

if not output_quantized:
return False
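
For context, a minimal standalone sketch of what the refactored check computes (hypothetical helper names, not part of this diff); note that the new code checks the dtype of the node itself rather than of each user:

import torch

def output_is_quantized(user_targets, q_op, node_dtype, already_quantized=False):
    # Hedged sketch, not the actual method in tosa_supported_operators.py.
    # The output counts as quantized if it already was, if every user is a
    # quantize op, or if the node itself does not produce floating-point data.
    all_q_users = all(target == q_op for target in user_targets)
    is_floating_point = node_dtype.is_floating_point
    return already_quantized or all_q_users or not is_floating_point

# dtype.is_floating_point is the property used to tell float from integer types:
assert torch.float32.is_floating_point
assert not torch.int8.is_floating_point
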
119 changes: 112 additions & 7 deletions backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
@@ -19,21 +19,39 @@
)

input_t1 = Tuple[torch.Tensor]
aten_op: list[str] = ["torch.ops.aten.add.Tensor", "torch.ops.aten.softplus.default"]
exir_op: list[str] = [
softplus_aten_op: list[str] = [
"torch.ops.aten.add.Tensor",
"torch.ops.aten.softplus.default",
]
softplus_exir_op: list[str] = [
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
"executorch_exir_dialects_edge__ops_aten_exp_default",
"executorch_exir_dialects_edge__ops_aten_div_Tensor",
]

linear_residual_aten_op: list[str] = [
"torch.ops.aten.linear.default",
"torch.ops.aten.gelu.default",
"torch.ops.aten.dropout.default",
"torch.ops.aten.add.Tensor",
]
linear_residual_exir_op: list[str] = [
"executorch_exir_dialects_edge__ops_aten_gelu_default",
"executorch_exir_dialects_edge__ops_aten_clone_default",
"executorch_exir_dialects_edge__ops_aten_linear_default",
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
]


test_data: dict[input_t1] = {
"3d_rand": (torch.rand(1, 5, 5),),
}


class Module(torch.nn.Module):
class SoftplusModule(torch.nn.Module):
"""Module containing an addition followed by a Softplus. Softplus is currently not supported by TosaBackend."""

def __init__(self):
super().__init__()
self.softplus = torch.nn.Softplus()
@@ -42,10 +60,35 @@ def forward(self, x: torch.Tensor):
return self.softplus(x + x)


class LinearResidualModule(torch.nn.Module):
"""Module containing a residual and a linear layer followed by GELU and a Dropout.
GELU is currently not supported by TosaBackend nor TosaQuantizer.
"""

def __init__(
self,
):
super().__init__()
self.linear = torch.nn.Linear(in_features=5, out_features=3)
self.gelu = torch.nn.GELU()
self.dropout = torch.nn.Dropout(0.5)

def forward(self, x: torch.Tensor):
x1 = self.linear(x)
x2 = self.gelu(x1)
x3 = self.dropout(x2)
return x1 + x3


# Softplus is decomposed, which breaks the expected quantization pattern. This test checks that
# CheckProperQuantization does not partition nodes where the quantization is not as expected.
@common.parametrize("test_data", test_data)
def test_softplus_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](
Module(), test_data=test_data, aten_op=aten_op, exir_op=exir_op
SoftplusModule(),
test_data=test_data,
aten_op=softplus_aten_op,
exir_op=softplus_exir_op,
)
# remove check_count.exir as there will be more than one delegate
pipeline.pop_stage("check_count.exir")
@@ -55,14 +98,76 @@ def test_softplus_tosa_MI(test_data: input_t1):
@common.parametrize("test_data", test_data)
def test_softplus_tosa_BI(test_data: input_t1):
pipeline = TosaPipelineBI[input_t1](
Module(), test_data=test_data, aten_op=aten_op, exir_op=exir_op
SoftplusModule(),
test_data=test_data,
aten_op=softplus_aten_op,
exir_op=softplus_exir_op,
)
pipeline.pop_stage("check_not.exir")
# check that all ops in softplus_exir_op except add are rejected
pipeline.add_stage_after(
"to_edge_transform_and_lower",
pipeline.tester.check,
softplus_exir_op[1:],
suffix="exir_post_partition",
)
pipeline.run()


# Since GELU will not be quantized by TosaQuantizer, the Dropout's input will not be quantized either.
# In that case, the Dropout must not be partitioned by TosaPartitioner for the TOSA BI profile. This test checks
# that the partitioner indeed does not partition the Dropout (clone) for TOSA BI.
@common.parametrize("test_data", test_data)
def test_linear_residual_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](
LinearResidualModule(),
test_data=test_data,
aten_op=linear_residual_aten_op,
exir_op=linear_residual_exir_op,
use_to_edge_transform_and_lower=True,
)
# remove check_count.exir as there will be more than one delegate
pipeline.pop_stage("check_count.exir")
pipeline.pop_stage("check_not.exir")
# check that all ops in linear_residual_exir_op except GELU are partitioned
pipeline.add_stage_after(
"to_edge_transform_and_lower",
pipeline.tester.check_not,
linear_residual_exir_op[1:],
suffix="exir_post_partition",
)
pipeline.add_stage_after(
"to_edge_transform_and_lower",
pipeline.tester.check,
linear_residual_exir_op[:1],
suffix="exir_post_partition",
)
pipeline.run()


@common.parametrize("test_data", test_data)
def test_linear_residual_tosa_BI(test_data: input_t1):
pipeline = TosaPipelineBI[input_t1](
LinearResidualModule(),
test_data=test_data,
aten_op=linear_residual_aten_op,
exir_op=linear_residual_exir_op,
use_to_edge_transform_and_lower=True,
)
# remove check_count.exir as there will be more than one delegate
pipeline.pop_stage("check_count.exir")
pipeline.pop_stage("check_not.exir")
# check that all ops in exir_op except add are rejected
# check that all ops in linear_residual_exir_op except GELU and Dropout are partitioned
pipeline.add_stage_after(
"to_edge_transform_and_lower",
pipeline.tester.check_not,
linear_residual_exir_op[2:],
suffix="exir_post_partition",
)
pipeline.add_stage_after(
"to_edge_transform_and_lower",
pipeline.tester.check,
exir_op[1:],
linear_residual_exir_op[:2],
suffix="exir_post_partition",
)
pipeline.run()
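
For readability, a hedged illustration (plain Python, not part of the test file) of what the list slices in the checks above select:

linear_residual_exir_op = [
    "executorch_exir_dialects_edge__ops_aten_gelu_default",
    "executorch_exir_dialects_edge__ops_aten_clone_default",
    "executorch_exir_dialects_edge__ops_aten_linear_default",
    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
]

# MI: only GELU ([:1]) is expected to stay outside the delegate ...
assert linear_residual_exir_op[:1] == [
    "executorch_exir_dialects_edge__ops_aten_gelu_default"
]
# ... while clone, linear and add ([1:]) must disappear into it.
assert len(linear_residual_exir_op[1:]) == 3

# BI: GELU and the Dropout (lowered to clone) stay outside the delegate ([:2]),
# and linear and add ([2:]) must be delegated.
assert [op.split("_aten_")[-1] for op in linear_residual_exir_op[:2]] == [
    "gelu_default",
    "clone_default",
]
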
17 changes: 16 additions & 1 deletion backends/arm/tosa_partitioner.py
@@ -14,6 +14,7 @@
get_tosa_spec,
is_tosa,
) # usort: skip
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operator_support.tosa_supported_operators import (
tosa_support_factory,
)
@@ -66,7 +67,7 @@ def __init__(
self.delegation_spec = DelegationSpec(TOSABackend.__name__, compile_spec)
self.additional_checks = additional_checks

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa
# Run the CapabilityBasedPartitioner to return the largest possible
# subgraphs containing the nodes with the tags

@@ -110,6 +111,20 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
del node.meta["delegation_tag"]
break

if tosa_spec.support_float():
continue

if is_partitioned(node):
for input in node.all_input_nodes:
if is_partitioned(input):
continue
if get_first_fake_tensor(input).dtype.is_floating_point:
logger.info(
f"Not partitioning {node.name} becuase input {input.name} has floating point dtype."
)
del node.meta["delegation_tag"]
break

tag_constant_data(exported_program)

return PartitionResult(
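
As a rough guide to the control flow, a minimal sketch of the added guard written as a standalone function (hypothetical name and signature; the real logic runs inline inside the partition() method shown above):

import logging

logger = logging.getLogger(__name__)

def remove_tags_for_float_fed_nodes(graph_module, tosa_spec, is_partitioned, get_first_fake_tensor):
    # Hypothetical standalone sketch of the added check: on integer-only TOSA
    # specs, a partitioned node whose input is produced outside the partition
    # and carries a floating-point dtype cannot be lowered, so its delegation
    # tag is removed.
    if tosa_spec.support_float():
        return
    for node in graph_module.graph.nodes:
        if not is_partitioned(node):
            continue
        for input_node in node.all_input_nodes:
            if is_partitioned(input_node):
                continue
            if get_first_fake_tensor(input_node).dtype.is_floating_point:
                logger.info(
                    f"Not partitioning {node.name} because input "
                    f"{input_node.name} has floating point dtype."
                )
                del node.meta["delegation_tag"]
                break
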