Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ def _merge_supported_types(
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
assert node.target in self.targets

supported_dtypes = (
self.ALL_SUPPORTED_TYPES
if tosa_spec.support_float()
Expand All @@ -90,10 +88,27 @@ def is_node_tosa_supported(
if v in supported_dtypes
)

# Check input type
assert len(node.all_input_nodes) == 1
if len(node.all_input_nodes) != 1:
self.reporter.report_reject(
node,
(
"Expected exactly one input node, "
f"got {len(node.all_input_nodes)} for {node.target}."
),
)
return False
input_val = node.all_input_nodes[0].meta["val"]
assert isinstance(input_val, torch._subclasses.FakeTensor)
if not isinstance(input_val, torch._subclasses.FakeTensor):
self.reporter.report_reject(
node,
(
"Invalid or missing meta: expected FakeTensor input, got "
f"{type(input_val).__name__} for {node.target}."
),
)
return False

# Check input type
input_dtype = input_val.dtype
if input_dtype not in supported_dtypes:
self.reporter.report_reject(
Expand All @@ -104,14 +119,24 @@ def is_node_tosa_supported(

# Check output type
output_val = node.meta["val"]
assert isinstance(output_val, torch._subclasses.FakeTensor)
if not isinstance(output_val, torch._subclasses.FakeTensor):
self.reporter.report_reject(
node,
(
"Invalid or missing meta: expected FakeTensor output, got "
f"{type(output_val).__name__} for {node.target}."
),
)
return False
if output_val.dtype not in supported_dtypes[input_dtype]:
self.reporter.report_reject(
node,
f"Output dtype {output_val.dtype} is not supported in "
f"{node.target} for input dtype {input_dtype}. "
f"Supported output types: "
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
(
f"Output dtype {output_val.dtype} is not supported in "
f"{node.target} for input dtype {input_dtype}. "
f"Supported output types: "
f"{', '.join(str(t) for t in supported_dtypes[input_dtype])}"
),
)
return False

Expand All @@ -120,8 +145,10 @@ def is_node_tosa_supported(
if node.kwargs["memory_format"] in (torch.preserve_format,):
self.reporter.report_reject(
node,
f"Argument 'memory_format' is not supported for "
f"{node.target} right now.",
(
"Argument 'memory_format' is not supported for "
f"{node.target} right now."
),
)
return False

Expand All @@ -132,8 +159,10 @@ def is_node_tosa_supported(
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
self.reporter.report_reject(
node,
f"Argument {dim_order=} is not supported for "
f"{node.target} right now.",
(
f"Argument {dim_order=} is not supported for "
f"{node.target} right now."
),
)
return False

Expand Down
Loading