Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import ( # noqa
mean_dim_support,
right_shift_support,
to_copy_support,
tosa_supported_operators,
var_correction_support,
)
120 changes: 120 additions & 0 deletions backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import logging

import torch

import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)


@register_tosa_support_check
class ToCopySupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten._to_copy.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

@staticmethod
def _merge_supported_types(
dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict
) -> SupportedTypeDict:
merged_dtypes = dtypes1
for k, v in dtypes2.items():
merged_dtypes[k] = merged_dtypes.get(k, []) + v
return merged_dtypes

SUPPORTED_INT_TYPES: SupportedTypeDict = {
torch.bool: [torch.int8, torch.int16, torch.int32],
torch.int8: [torch.bool, torch.int16, torch.int32],
torch.int16: [torch.bool, torch.int8, torch.int32],
torch.int32: [torch.bool, torch.int8, torch.int16],
}
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
torch.int8: [torch.float16, torch.bfloat16, torch.float32],
torch.int16: [torch.float16, torch.bfloat16, torch.float32],
torch.int32: [torch.float16, torch.bfloat16, torch.float32],
torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
torch.float32: [
torch.int8,
torch.int16,
torch.int32,
torch.bfloat16,
torch.float16,
],
}
ALL_SUPPORTED_TYPES = _merge_supported_types(
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
)
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
assert node.target in self.targets

if tosa_spec not in self.tosa_specs:
return False

assert tosa_spec.support_integer()
supported_dtypes = (
self.ALL_SUPPORTED_TYPES
if tosa_spec.support_float()
else self.SUPPORTED_INT_TYPES
)
# Take into account possible type conversions
supported_dtypes.update(
(k, supported_dtypes[v])
for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
if v in supported_dtypes
)

# Check input type
assert len(node.all_input_nodes) == 1
input_val = node.all_input_nodes[0].meta["val"]
assert isinstance(input_val, torch._subclasses.FakeTensor)
input_dtype = input_val.dtype
if input_dtype not in supported_dtypes:
logger.info(
f"Input dtype {input_val.dtype} is not supported in "
f"{node.target.name()}."
)
return False

# Check output type
output_val = node.meta["val"]
assert isinstance(output_val, torch._subclasses.FakeTensor)
if output_val.dtype not in supported_dtypes[input_dtype]:
logger.info(
f"Output dtype {output_val.dtype} is not supported in "
f"{node.target.name()} for input dtype {input_dtype}. "
f"Supported output types: "
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
)
return False

# Check memory format
if "memory_format" in node.kwargs:
if node.kwargs["memory_format"] in (torch.preserve_format,):
logger.info(
f"Argument 'memory_format' is not supported for "
f"{node.target.name()} right now."
)
return False

return True
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
op_sub,
op_sum,
op_tanh,
op_to_copy,
op_transpose,
op_unsqueeze,
op_upsample_nearest2d,
Expand Down
43 changes: 43 additions & 0 deletions backends/arm/operators/op_to_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
import tosa.Op as TosaOp

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg


@register_node_visitor
class ToCopyVisitor(NodeVisitor):
"""
Implement the type cast functionality of _to_copy.

Other features like setting of the memory_format or moving a tensor to a
different device are not supported.

Also note that the node should not be quantized.
"""

target = "aten._to_copy.default"

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
assert not is_quant_node, "Casting of quantized values is not supported."
assert inputs
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])
16 changes: 14 additions & 2 deletions backends/arm/test/ops/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,21 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
.run_method_and_compare_outputs(inputs=test_data)
)

# Most MI tests fail, just show one working for now.
@parameterized.expand((tensor_scalar_tests[6],))
@parameterized.expand(tensor_scalar_tests)
def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
expected_exception = None
if any(token in test_name for token in ("Sub_int", "Sub__int")):
expected_exception = RuntimeError
elif test_name.endswith("_st"):
expected_exception = AttributeError

if expected_exception:
with self.assertRaises(
expected_exception, msg=f"Test {test_name} is expected to fail."
):
self._test_add_tosa_MI_pipeline(op, (x, y))
return

self._test_add_tosa_MI_pipeline(op, (x, y))

# op(Scalar float, tensor) works if the scalar is constant.
Expand Down
70 changes: 70 additions & 0 deletions backends/arm/test/ops/test_to_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the _to_copy op which is interpreted as a cast for our purposes.
#

import unittest

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from parameterized import parameterized


class Cast(torch.nn.Module):
def __init__(self, target_dtype):
super().__init__()
self.target_dtype = target_dtype

def forward(self, x: torch.Tensor):
return x.to(dtype=self.target_dtype)


class TestToCopy(unittest.TestCase):
"""
Tests the _to_copy operation.

Only test unquantized graphs as explicit casting of dtypes messes with the
quantization.

Note: This is also covered by test_scalars.py.
"""

_TO_COPY_TEST_DATA = (
(torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.float32),
(torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.float16),
(torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.float32),
(torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.int32),
(torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.int8),
)

def _test_to_copy_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: torch.Tensor
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
)
.export()
.dump_artifact()
.check_count({"torch.ops.aten._to_copy.default": 1})
.to_edge()
.dump_artifact()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

@parameterized.expand(_TO_COPY_TEST_DATA)
def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_dtype):
self._test_to_copy_tosa_MI_pipeline(Cast(new_dtype), (test_tensor,))
Loading