Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,19 +335,24 @@ def ops_to_not_decompose(
function that returns True when an op should not be decomposed.

"""
ops_to_not_decompose_if_quant_op = [
ops_to_not_decompose_if_quant_op = {
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.linear.default,
]
}
ops_to_not_decompose_if_fp = {
torch.ops.aten.linear.default,
}
ops_to_not_decompose_always = {
torch.ops.aten.eye.default,
torch.ops.aten.linspace.default,
torch.ops.aten.logit.default,
}

def filter_fn(node: torch.fx.Node) -> bool:
"""Return True to keep selected ops intact inside quantized regions.

The predicate holds when the target is in
``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are
quantize/dequantize ops, indicating a quantized activation that
should not be decomposed.
"""Filter function applied to ops in 'ops_to_not_decompose'.
Returns True if the op should not be decomposed.
If this function returns True, the partitioner *must* accept the node, or the lowering fails.

Args:
node (torch.fx.Node): FX node to evaluate.
Expand All @@ -356,6 +361,12 @@ def filter_fn(node: torch.fx.Node) -> bool:
bool: True to keep the op intact; otherwise, False.

"""
if (
self.tosa_spec.support_float()
and node.target in ops_to_not_decompose_if_fp
):
return True

dq = (
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
Expand Down Expand Up @@ -394,11 +405,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
# By default, do not decompose the operator
return True

ops_to_not_decompose = [
torch.ops.aten.eye.default,
torch.ops.aten.linspace.default,
torch.ops.aten.logit.default,
] + ops_to_not_decompose_if_quant_op
ops_to_not_decompose = list(
ops_to_not_decompose_always
| ops_to_not_decompose_if_quant_op
| ops_to_not_decompose_if_fp
)

if not self.tosa_spec.is_U55_subset:
# Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d
Expand Down
Loading