Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/arm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ The Arm EthosU Backend should be considered a prototype quality at this point, l
## Current flows

The EthosUBackend has a two-stage process:
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80 TOSA BI with specific concern to a subset which gives support on Ethos-U55 and Ethos-U85, the target of the initial prototype efforts. This calls into the TOSABackend.
- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v1.0 TOSA INT with specific concern to a subset which gives support on Ethos-U55 and Ethos-U85, the target of the initial prototype efforts. This calls into the TOSABackend.
- Lower via the ethos-u-vela compilation flow which takes TOSA v1.0 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.

The EthosUPartitioner is currently used to ensure the operations converted are Ethos-U compatible, but will be extended to offer spec-correct TOSA Base Inference and TOSA Main Inference generation in future.

Expand Down
20 changes: 5 additions & 15 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _transform(self, graph_module: GraphModule):
with TosaLoweringContext(self.tosa_spec):
return self(graph_module).graph_module

def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
Expand Down Expand Up @@ -162,7 +162,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeMaskedFill())
self.add_pass(DecomposeRoundPass())
self.add_pass(DecomposeAcoshPass())
Expand Down Expand Up @@ -235,22 +235,12 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

return self._transform(exported_program.graph_module)

def _tosa_1_0_int_quantized_pipeline(self, exported_program: ExportedProgram):
return self._tosa_080_BI_pipeline(exported_program)

def _tosa_1_0_fp_pipeline(self, exported_program: ExportedProgram):
return self._tosa_080_MI_pipeline(exported_program)

def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
"""Apply passes before transforming program to backend"""
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
return self._tosa_080_BI_pipeline(exported_program)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
return self._tosa_080_MI_pipeline(exported_program)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
return self._tosa_1_0_fp_pipeline(exported_program)
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
return self._tosa_FP_pipeline(exported_program)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
return self._tosa_1_0_int_quantized_pipeline(exported_program)
return self._tosa_INT_pipeline(exported_program)
else:
raise NotImplementedError(
f"No pass pipeline implemented for {self.tosa_spec=}"
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/convolution_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.convolution.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/embedding_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.embedding.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/index_select_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ class IndexSelectSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.index_select.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/index_tensor_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.index.Tensor]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
1 change: 0 additions & 1 deletion backends/arm/operator_support/minmax_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class MinMaxSupported(SupportedTOSAOperatorCheck):

# TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer"
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

Expand Down
4 changes: 0 additions & 4 deletions backends/arm/operator_support/pool_2d_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down Expand Up @@ -122,8 +120,6 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/reduce_sum_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class SumSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.sum.dim_IntList]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/right_shift_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
1 change: 0 additions & 1 deletion backends/arm/operator_support/sin_cos_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class SinCosSupported(SupportedTOSAOperatorCheck):
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/slice_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ class SliceCopySupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.slice_copy.Tensor]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def is_node_tosa_supported(

# container for all SupportedTosaOperatorCheck classes
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
}
Expand Down
11 changes: 1 addition & 10 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,11 @@ class NodeVisitor:
# a specific TOSA version.
# When all node_visitors have been refactored to target a specific
# version, this list should be removed.
tosa_specs_1_00 = [
tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

tosa_specs_0_80 = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

tosa_specs = tosa_specs_0_80 + tosa_specs_1_00

def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
self._exported_program = exported_program
self.tosa_spec = tosa_spec
Expand All @@ -52,8 +45,6 @@ def define_node(

# container for all node visitors
_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
TosaSpecification.create_from_string("TOSA-1.0+INT"): {},
TosaSpecification.create_from_string("TOSA-1.0+FP"): {},
}
Expand Down
105 changes: 0 additions & 105 deletions backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,111 +23,6 @@
from torch.fx import Node


@register_node_visitor
class AbsVisitor_080_BI(NodeVisitor):
    """Node visitor that emits a TOSA ABS operator for ``aten.abs.default``.

    Registered for the TOSA-0.80+BI specification. Accepts INT8 (quantized)
    and INT32 tensors; the ABS operator itself is always issued on INT32
    data, with rescale ops inserted around it for the INT8 case.
    """

    target = "aten.abs.default"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add the ABS operator (plus any rescales) for ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered; forwarded to the rescale helpers.
            tosa_graph: Serializer graph that receives the new operators.
            inputs: Operator inputs; exactly one is expected.
            output: Destination tensor for the result.
        """
        # Imported lazily so the v0.80 serializer is only needed when used.
        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore

        all_tensors = [*inputs, output]
        validate_num_inputs(self.target, inputs, 1)
        validate_same_dtype(self.target, all_tensors, ts)
        # Only int8 (quantized) and int32 are handled by this visitor.
        validate_valid_dtype(
            self.target,
            all_tensors,
            [ts.DType.INT8, ts.DType.INT32],
            output.tosa_spec,
        )

        if inputs[0].dtype == ts.DType.INT8:
            # Quantized input: rescale to INT32 before taking the abs.
            int32_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node
            )  # type: ignore[possibly-undefined]
        else:
            # INT32 input is non-quantized and natively supported by TOSA ABS.
            int32_inputs = inputs

        requantize_output = output.dtype == ts.DType.INT8
        if requantize_output:
            # Compute into an INT32 intermediate that is rescaled afterwards.
            abs_result = tosa_graph.addIntermediate(
                tutils.tosa_shape(output.shape, output.dim_order),
                ts.DType.INT32,
            )
        else:
            # INT32 output: write the result straight into the output tensor.
            abs_result = output

        # The actual INT32 abs.
        tosa_graph.addOperator(
            ts.TosaOp.Op().ABS,
            [int32_inputs[0].name],
            [abs_result.name],
            None,
        )

        if requantize_output:
            # Scale the INT32 result back to 8 bit.
            # pyre-ignore
            tqutils.insert_rescale_op_to_int8(tosa_graph, abs_result, scale_back, node)  # type: ignore[possibly-undefined]

@register_node_visitor
class AbsVisitor_080_MI(AbsVisitor_080_BI):
    """Node visitor for ``aten.abs.default`` under the TOSA-0.80+MI specification.

    ``target`` is inherited from ``AbsVisitor_080_BI``. Integer inputs are
    delegated to the parent implementation; FP32 is lowered directly to ABS.
    """

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add the ABS lowering for ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer graph that receives the new operator.
            inputs: Operator inputs; exactly one is expected.
            output: Destination tensor for the result.
        """
        # Imported lazily so the v0.80 serializer is only needed when used.
        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore

        validate_num_inputs(self.target, inputs, 1)
        validate_same_dtype(self.target, [*inputs, output], ts)

        integer_dtypes = (ts.DType.INT8, ts.DType.INT32)
        if inputs[0].dtype in integer_dtypes:
            # Integer tensors reuse the inherited BI lowering.
            super().define_node(node, tosa_graph, inputs, output)
            return

        # Remaining case: FP32 abs lowering.
        validate_valid_dtype(
            self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
        )

        # MI lowering: ABS applied directly, no rescaling.
        tosa_graph.addOperator(
            ts.TosaOp.Op().ABS,
            [inputs[0].name],
            [output.name],
            None,
        )


@register_node_visitor
class AbsVisitor_INT(NodeVisitor):
target = "aten.abs.default"
Expand Down
Loading
Loading