diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 909be88f867..27c5e60bdaf 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -10,6 +10,7 @@ # backends. Converts via TOSA as an intermediate form supported by AoT and # JIT compiler flows. # +from enum import Enum from typing import List, Optional from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found] @@ -22,12 +23,16 @@ class ArmCompileSpecBuilder: + class DebugMode(Enum): + JSON = 1 + def __init__(self): self.compile_spec: List[CompileSpec] = [] self.compiler_flags = [] self.output_format = None self.path_for_intermediates = None self.tosa_spec = None + self.tosa_debug_mode = None def vgf_compile_spec( self, @@ -163,6 +168,13 @@ def dump_intermediate_artifacts_to( self.path_for_intermediates = output_path return self + def dump_debug_info(self, debug_mode: DebugMode) -> "ArmCompileSpecBuilder": + """ + Dump debugging information into the intermediates path + """ + self.tosa_debug_mode = debug_mode.name + return self + def build(self) -> List[CompileSpec]: """ Generate a list of compile spec objects from the builder @@ -188,6 +200,16 @@ def build(self) -> List[CompileSpec]: CompileSpec("debug_artifact_path", self.path_for_intermediates.encode()) ) + if self.tosa_debug_mode is not None: + if not self.path_for_intermediates: + raise ValueError( + "dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()" + ) + + self.compile_spec.append( + CompileSpec("dump_debug_info", self.tosa_debug_mode.encode()) + ) + return self.compile_spec diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index afc80bbb849..29d4fde1dd5 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -5,10 +5,11 @@ # pyre-unsafe -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch +from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.export import ExportedProgram @@ -29,9 +30,38 @@ class NodeVisitor: TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification): + def __init__( + self, + exported_program: ExportedProgram, + tosa_spec: TosaSpecification, + debug_hook: Optional[DebugHook] = None, + ): self._exported_program = exported_program self.tosa_spec = tosa_spec + self.debug_hook = debug_hook + + def _serialize_operator( + self, + node: torch.fx.Node, + tosa_graph: Any, + tosa_op: Any, + inputs: List[str], + outputs: List[str], + attributes: Optional[Any] = None, + ) -> None: + tosa_graph.addOperator( + tosa_op, + inputs=inputs, + outputs=outputs, + attributes=attributes, + ) + + if self.debug_hook: + self.debug_hook.add( + node, + tosa_op=outputs[0], + tosa_op_id=tosa_op, + ) def define_node( self, diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index 3000af50ed7..33de1fa048f 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -123,7 +123,9 @@ def define_node( ) # MI lowering - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().ABS, [inputs[0].name], [output.name], diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 357800865cb..bd8440d6346 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -73,7 +73,9 @@ def define_node( input1, input2 = rescaled_inputs # Do the INT32 Add - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().ADD, [input1.name, input2.name], [add_output.name], @@ -127,7 +129,9 @@ def define_node( input1, input2 = inputs # FP lowering - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().ADD, [input1.name, input2.name], [output.name], diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index 526d6ff35ec..626b6ec16ef 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -61,6 +61,11 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1) - tosa_graph.addOperator( - ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().REDUCE_MAX, + [input.name], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 85b0b757c85..1e4093fd9b2 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -61,6 +61,11 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1) - tosa_graph.addOperator( - ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().REDUCE_MIN, + [input.name], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index 0ac307aedd4..ee5b2bdfdc6 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -52,6 +52,11 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim)) - tosa_graph.addOperator( - ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().REDUCE_ANY, + [inputs[0].name], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 9faf8272473..24c0e969a32 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -100,7 +100,9 @@ def _build_generic_avgpool2d( shape=[1], dtype=output.dtype, vals=[output_zp] ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().AVG_POOL2D, [input_tensor.name, input_zp_tensor.name, output_zp_tensor.name], [output.name], diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index c9bb0b003ee..ce795af2261 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -78,7 +78,9 @@ def define_node( tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=f"{node.name}_B_ZP") # Add the MATMUL to the TOSA graph. - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().MATMUL, [ inputs[0].name, diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 884bfb22a40..9c03cdaa3b8 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -47,7 +47,9 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ConcatAttribute(dim) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().CONCAT, [tensor.name for tensor in tensors], [output.name], diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 2bdeb89a713..e61a875cebe 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -90,8 +90,13 @@ def define_node( nan_mode=1, ) - tosa_graph.addOperator( - ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().CLAMP, + [inputs[0].name], + [output.name], + attr, ) @@ -138,6 +143,11 @@ def define_node( nan_mode=1, ) - tosa_graph.addOperator( - ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().CLAMP, + [inputs[0].name], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 147a1544ce9..8e39420f88f 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -100,7 +100,9 @@ def define_node( shape=[1], dtype=pad_const_dtype, vals=[pad_const_val] ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().PAD, [inputs[0].name, padding.name, pad_const.name], [output.name], diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 0bbe67c4beb..be7e6be2fbe 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -157,7 +157,9 @@ def define_node( reshape_attr = ts.TosaSerializerAttribute() reshape_attr.ReshapeAttribute() - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().RESHAPE, [weight.name, shape.name], [weight_reshaped.name], @@ -188,7 +190,9 @@ def define_node( acc_type=acc_type, ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, tosa_op, [ input.name, diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index e37db290d55..ae8c6d86d01 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -44,4 +44,6 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - tosa_graph.addOperator(ts.TosaOp.Op().COS, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().COS, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index eb5b3000d6c..9d32348a25b 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -68,9 +68,11 @@ def define_node( input_nodes = rescaled_inputs # Do the equal comparison - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().EQUAL, [input_nodes[0].name, input_nodes[1].name], - output.name, + [output.name], None, ) diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index e238c4fd80a..b46bc55e873 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -48,4 +48,6 @@ def define_node( ) # MI lowering - tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().ERF, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 96c077c838b..adff20ae6d5 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -48,4 +48,6 @@ def define_node( output.tosa_spec, ) - tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().EXP, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 723706702f0..218651b74c1 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -67,7 +67,9 @@ def define_node( # Update IO input_nodes = rescaled_inputs - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().GREATER_EQUAL, [input_nodes[0].name, input_nodes[1].name], [output.name], diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index e79ed009e24..29c0717a0a1 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -67,7 +67,9 @@ def define_node( # Update IO input_nodes = rescaled_inputs - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().GREATER, [input_nodes[0].name, input_nodes[1].name], [output.name], diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index a42f85abc4c..a925f5ee20b 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -86,7 +86,9 @@ def define_node( tosa_graph, indices.name, indices_new_shape, indices_reshaped.name ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().GATHER, [weights_reshaped.name, indices_reshaped.name], [output_name], diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 7afd7fe6612..598a3eea7aa 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -167,7 +167,9 @@ def define_node( data = np.full(index_shape, int(values_strides[i] / C)) mul_const = tosa_graph.addConst(index_shape, index_dtype, data) tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().MUL, [index_name, mul_const.name, f"{node.name}_{i}_shift"], [stride_shifted_indices.name], @@ -194,7 +196,9 @@ def define_node( reshaped_idxs.shape, reshaped_idxs.dtype, ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().ADD, [gather_index_name, reshaped_idxs.name], [add_idxs.name], @@ -217,7 +221,9 @@ def define_node( gather_out_shape, output.dtype, ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().GATHER, [reshaped_input.name, gather_index_name], [gather_out.name], diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index 9301f91cb4c..56cc67d4298 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -67,7 +67,9 @@ def define_node( # Update IO input_nodes = rescaled_inputs - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().GREATER_EQUAL, [input_nodes[1].name, input_nodes[0].name], [output.name], diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 8a48fe4fda5..9a68de66f9a 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -45,4 +45,6 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - tosa_graph.addOperator(ts.TosaOp.Op().LOG, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().LOG, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 31083e93590..89d745d4759 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -67,7 +67,9 @@ def define_node( # Update IO input_nodes = rescaled_inputs - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().GREATER, [input_nodes[1].name, input_nodes[0].name], [output.name], diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 754fcfcd638..01b4c8f5521 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -94,7 +94,9 @@ def define_node( kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1 ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().MAX_POOL2D, [input_tensor.name], [output.name], diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 27e5fdc2e02..45ca36dd8fc 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -87,7 +87,9 @@ def define_node( # Set to PROPOGATE as default attr_maximum.MaximumAttribute(nan_mode=NanPropagationMode.PROPAGATE) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().MAXIMUM, [ operand_inputs[0].name, diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 9dfa7d1f394..e91053d741a 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -86,7 +86,9 @@ def define_node( # Set to PROPOGATE as default attr_minimum.MinimumAttribute(nan_mode=NanPropagationMode.PROPAGATE) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().MINIMUM, [ operand_inputs[0].name, diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 7d9f6eac6aa..5ea86750b0e 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -93,7 +93,9 @@ def define_node( # Do the INT32 Mul tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().MUL, [input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"], [mul_output.name], @@ -135,7 +137,9 @@ def define_node( input1, input2 = inputs tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().MUL, [input1.name, input2.name, f"{node.name}_shift"], [output.name], diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index 54f3dafe769..90dda965f2d 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -82,7 +82,9 @@ def define_node( (1,), output.dtype, [output_zp], name=output.name + "_output_zp" ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().NEGATE, [inputs[0].name, input_zp_tensor.name, output_zp_tensor.name], [output.name], diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 0830d8f4504..e81e33949de 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -135,6 +135,11 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.TransposeAttribute(permutation_vector) - tosa_graph.addOperator( - ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().TRANSPOSE, + [inputs[0].name], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 413160c902a..027cfaf4adf 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -50,7 +50,9 @@ def define_node( output.tosa_spec, ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().POW, [ inputs[0].name, diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 3838afd9728..811186731b4 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -46,6 +46,6 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - tosa_graph.addOperator( - ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 3e636e993b7..6c569ae1325 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -60,7 +60,9 @@ def define_node( name=node.name + "_multiples", ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().TILE, [inputs[0].name, multiple_shapes.name], [output.name], diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 5313f5c8143..4749df49f01 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -53,7 +53,9 @@ def define_node( round = True attr.ArithmeticRightShiftAttribute(round=round) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, [inputs[0].name, inputs[1].name], [output.name], diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index df293946ded..89c60e22239 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -46,4 +46,6 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index dec42ae15f9..fdc305d0fc7 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -45,4 +45,6 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - tosa_graph.addOperator(ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index 1ea637f960a..da3453f7850 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -44,4 +44,6 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - tosa_graph.addOperator(ts.TosaOp.Op().SIN, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().SIN, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 56115073ce1..07b31c4c05f 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -117,7 +117,9 @@ def define_node( (sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape" ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().SLICE, [input_node.name, start_tensor.name, sizes_tensor.name], [output.name], diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 4701c488967..2dd4b4fe854 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -72,7 +72,9 @@ def define_node( sub_output = output # Do the INT32 Sub - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().SUB, [ rescaled_inputs[0].name, @@ -127,7 +129,9 @@ def define_node( ) # MI lowering - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().SUB, [inputs[0].name, inputs[1].name], [output.name], diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py index 4886a513881..7931ba9a1ca 100644 --- a/backends/arm/operators/op_table.py +++ b/backends/arm/operators/op_table.py @@ -60,7 +60,9 @@ def define_node( name=table_tensor_name, ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().TABLE, [inputs[0].name, table_tensor_name], [output.name], diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 0d149397eb6..7d5971e59f9 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -46,4 +46,6 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - tosa_graph.addOperator(ts.TosaOp.Op().TANH, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().TANH, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py index 9758a018b87..2ee09b2496b 100644 --- a/backends/arm/operators/op_to_copy.py +++ b/backends/arm/operators/op_to_copy.py @@ -44,4 +44,6 @@ def define_node( validate_num_inputs(self.target, inputs, 1) - tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().CAST, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index 74bf1a5ad14..cd5c45d459e 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -44,4 +44,6 @@ def define_node( validate_num_inputs(self.target, inputs, 1) - tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().CAST, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py index accd79e8546..7bd4be0c4f4 100644 --- a/backends/arm/operators/op_transpose.py +++ b/backends/arm/operators/op_transpose.py @@ -62,6 +62,11 @@ def define_node( perms = [dim % output_rank for dim in inputs[1].special] attr = ts.TosaSerializerAttribute() attr.TransposeAttribute(perms) - tosa_graph.addOperator( - ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().TRANSPOSE, + [inputs[0].name], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py index 26927bfcfa2..f7e9d17ed96 100644 --- a/backends/arm/operators/op_upsample_bilinear2d.py +++ b/backends/arm/operators/op_upsample_bilinear2d.py @@ -98,7 +98,9 @@ def in_int16_range(x): [len(border)], ts.DType.SHAPE, border, node.name + "_border" ) if input_dtype == output.dtype == ts.DType.FP32: - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().RESIZE, [ inputs[0].name, @@ -114,7 +116,9 @@ def in_int16_range(x): intermediate = tosa_graph.addIntermediate( tosa_shape(output.shape, output.dim_order), ts.DType.INT32 ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().RESIZE, [ inputs[0].name, diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py index 46dcc0605e6..3b93cd75a1a 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -89,7 +89,9 @@ def in_int16_range(x): mode=ResizeMode.NEAREST, ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().RESIZE, [ inputs[0].name, diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index 1e8c06b691f..9535e448a1b 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -66,6 +66,11 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute() - tosa_graph.addOperator( - ts.TosaOp.Op().RESHAPE, [inputs[0].name, shape.name], [output.name], attr + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().RESHAPE, + [inputs[0].name, shape.name], + [output.name], + attr, ) diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index e6a87be6387..2ec26ab63da 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -33,6 +33,7 @@ def __init__(self, *args): def _add_node_to_tosa_graph( self, + node: Node, tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, @@ -51,7 +52,9 @@ def _add_node_to_tosa_graph( output.tosa_spec, ) - tosa_graph.addOperator( + self._serialize_operator( + node, + tosa_graph, ts.TosaOp.Op().SELECT, [inputs[0].name, inputs[1].name, inputs[2].name], [output.name], @@ -73,7 +76,9 @@ def define_node( ts.DType.INT32, ts.DType.BOOL, ] - self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes) + self._add_node_to_tosa_graph( + node, tosa_graph, inputs, output, bi_supported_dtypes + ) @register_node_visitor @@ -103,4 +108,6 @@ def define_node( ts.DType.INT32, ts.DType.BOOL, ] - self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes) + self._add_node_to_tosa_graph( + node, tosa_graph, inputs, output, mi_supported_dtypes + ) diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index dc9bd446a34..2a8ac9582d3 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -65,8 +65,12 @@ def define_node( output.tosa_spec, ) - tosa_graph.addOperator( - tosa_op, [inputs[0].name, inputs[1].name], [output.name] + self._serialize_operator( + node, + tosa_graph, + tosa_op, + [inputs[0].name, inputs[1].name], + [output.name], ) register_node_visitor(BinaryOperator) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 238b033f8eb..94bd9605706 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -45,8 +45,12 @@ def define_node( validate_same_dtype(self.target, [*inputs, output], ts) # Simply add an identityOp - tosa_graph.addOperator( - ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().IDENTITY, + [inputs[0].name], + [output.name], ) register_node_visitor(IdentityOperatorVisitor) diff --git a/backends/arm/operators/ops_unary.py b/backends/arm/operators/ops_unary.py index 48092e13968..4ccebc2d467 100644 --- a/backends/arm/operators/ops_unary.py +++ b/backends/arm/operators/ops_unary.py @@ -54,7 +54,9 @@ def define_node( output.tosa_spec, ) - tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name]) + self._serialize_operator( + node, tosa_graph, tosa_op, [inputs[0].name], [output.name] + ) register_node_visitor(UnaryOperator) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index b01dec4d371..059c86de351 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -63,16 +63,24 @@ def maybe_get_tosa_collate_path() -> str | None: def get_tosa_compile_spec( - tosa_spec: str | TosaSpecification, custom_path=None + tosa_spec: str | TosaSpecification, + custom_path: Optional[str] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, ) -> list[CompileSpec]: """ Default compile spec for TOSA tests. """ - return get_tosa_compile_spec_unbuilt(tosa_spec, custom_path).build() + return get_tosa_compile_spec_unbuilt( + tosa_spec, + custom_path, + tosa_debug_mode, + ).build() def get_tosa_compile_spec_unbuilt( - tosa_spec: str | TosaSpecification, custom_path=None + tosa_spec: str | TosaSpecification, + custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], ) -> ArmCompileSpecBuilder: """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify the compile spec before calling .build() to finalize it. @@ -82,12 +90,16 @@ def get_tosa_compile_spec_unbuilt( if custom_path is not None: os.makedirs(custom_path, exist_ok=True) + compile_spec_builder = ( ArmCompileSpecBuilder() .tosa_compile_spec(tosa_spec) .dump_intermediate_artifacts_to(custom_path) ) + if tosa_debug_mode is not None: + compile_spec_builder.dump_debug_info(tosa_debug_mode) + return compile_spec_builder @@ -97,6 +109,7 @@ def get_u55_compile_spec( memory_mode: str = "Shared_Sram", extra_flags: str = "--debug-force-regor --output-format=raw", custom_path: Optional[str] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, config: Optional[str] = "Arm/vela.ini", ) -> list[CompileSpec]: """ @@ -108,16 +121,18 @@ def get_u55_compile_spec( memory_mode=memory_mode, extra_flags=extra_flags, custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, config=config, ).build() def get_u85_compile_spec( macs: int = 128, - system_config="Ethos_U85_SYS_DRAM_Mid", - memory_mode="Shared_Sram", - extra_flags="--output-format=raw", - custom_path=None, + system_config: str = "Ethos_U85_SYS_DRAM_Mid", + memory_mode: str = "Shared_Sram", + extra_flags: str = "--output-format=raw", + custom_path: Optional[str] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, config: Optional[str] = "Arm/vela.ini", ) -> list[CompileSpec]: """ @@ -129,6 +144,7 @@ def get_u85_compile_spec( memory_mode=memory_mode, extra_flags=extra_flags, custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, config=config, ).build() @@ -136,12 +152,15 @@ def get_u85_compile_spec( def get_vgf_compile_spec( tosa_spec: str | TosaSpecification, compiler_flags: Optional[str] = "", - custom_path=None, + custom_path: Optional[str] = "", + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, ) -> list[CompileSpec]: """ Default compile spec for VGF tests. """ - return get_vgf_compile_spec_unbuilt(tosa_spec, compiler_flags, custom_path).build() + return get_vgf_compile_spec_unbuilt( + tosa_spec, compiler_flags, custom_path, tosa_debug_mode + ).build() def get_u55_compile_spec_unbuilt( @@ -150,6 +169,7 @@ def get_u55_compile_spec_unbuilt( memory_mode: str, extra_flags: str, custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], config: Optional[str], ) -> ArmCompileSpecBuilder: """Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify @@ -173,6 +193,10 @@ def get_u55_compile_spec_unbuilt( ) .dump_intermediate_artifacts_to(artifact_path) ) + + if tosa_debug_mode is not None: + compile_spec.dump_debug_info(tosa_debug_mode) + return compile_spec @@ -182,6 +206,7 @@ def get_u85_compile_spec_unbuilt( memory_mode: str, extra_flags: str, custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], config: Optional[str], ) -> list[CompileSpec]: """Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify @@ -204,13 +229,18 @@ def get_u85_compile_spec_unbuilt( ) .dump_intermediate_artifacts_to(artifact_path) ) + + if tosa_debug_mode is not None: + compile_spec.dump_debug_info(tosa_debug_mode) + return compile_spec # type: ignore[return-value] def get_vgf_compile_spec_unbuilt( tosa_spec: str | TosaSpecification, - compiler_flags: Optional[str] = "", - custom_path=None, + compiler_flags: Optional[str], + custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], ) -> ArmCompileSpecBuilder: """Get the ArmCompileSpecBuilder for the default VGF tests, to modify the compile spec before calling .build() to finalize it. @@ -231,6 +261,9 @@ def get_vgf_compile_spec_unbuilt( .dump_intermediate_artifacts_to(artifact_path) ) + if tosa_debug_mode is not None: + compile_spec_builder.dump_debug_info(tosa_debug_mode) + return compile_spec_builder diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 288d5b41615..5648070f869 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -3,15 +3,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json import os import shutil import tempfile +from pathlib import Path from typing import Tuple import pytest import torch +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -188,6 +191,37 @@ def test_collate_tosa_INT_tests(test_data: input_t1): shutil.rmtree("test_collate_tosa_tests", ignore_errors=True) +@common.parametrize("test_data", Linear.inputs) +def test_dump_tosa_debug_json(test_data: input_t1): + with tempfile.TemporaryDirectory() as tmpdir: + pipeline = TosaPipelineINT[input_t1]( + module=Linear(), + test_data=test_data, + aten_op=[], + exir_op=[], + custom_path=tmpdir, + tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.JSON, + ) + + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + json_output_path = Path(tmpdir) / "debug.json" + + # The file should exist + assert json_output_path.exists() + + # Check the file is valid JSON and can be loaded + with json_output_path.open("r") as file: + try: + data = json.load(file) + + # Check it's not empty + assert data + except json.JSONDecodeError: + pytest.fail("Failed to load debug JSON file") + + @common.parametrize("test_data", Linear.inputs) def test_dump_tosa_ops(caplog, test_data: input_t1): pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], []) diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index b9c01e195b2..941c875f55a 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -22,6 +22,7 @@ import torch +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -339,6 +340,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -355,7 +357,9 @@ def __init__( tosa_version = conftest.get_option("tosa_version") compile_spec = common.get_tosa_compile_spec( - tosa_profiles[tosa_version], custom_path=custom_path + tosa_profiles[tosa_version], + custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, ) quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) @@ -441,6 +445,7 @@ def __init__( run_on_tosa_ref_model: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 0, @@ -460,7 +465,9 @@ def __init__( tosa_version = conftest.get_option("tosa_version") compile_spec = common.get_tosa_compile_spec( - tosa_profiles[tosa_version], custom_path=custom_path + tosa_profiles[tosa_version], + custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, ) super().__init__( module, @@ -519,11 +526,15 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, ): - compile_spec = common.get_u55_compile_spec(custom_path=custom_path) + compile_spec = common.get_u55_compile_spec( + custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, + ) quantizer = EthosUQuantizer(compile_spec) quantization_config = get_symmetric_quantization_config( is_per_channel=per_channel_quantization @@ -606,11 +617,15 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, ): - compile_spec = common.get_u85_compile_spec(custom_path=custom_path) + compile_spec = common.get_u85_compile_spec( + custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, + ) quantizer = EthosUQuantizer(compile_spec) quantization_config = get_symmetric_quantization_config( is_per_channel=per_channel_quantization @@ -915,6 +930,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -931,7 +947,10 @@ def __init__( tosa_version + "".join([f"+{ext}" for ext in tosa_extensions]) ) compile_spec = common.get_vgf_compile_spec( - tosa_spec, compiler_flags=vgf_compiler_flags, custom_path=custom_path + tosa_spec, + compiler_flags=vgf_compiler_flags, + custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, ) super().__init__( diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py index 7062d68b944..3f35679cc30 100644 --- a/backends/arm/tosa_backend.py +++ b/backends/arm/tosa_backend.py @@ -14,6 +14,7 @@ from typing import cast, final, List import serializer.tosa_serializer as ts # type: ignore +from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.operators.node_visitor import get_node_visitors from executorch.backends.arm.tosa_specification import get_tosa_spec from executorch.backends.arm._passes import ( @@ -62,6 +63,7 @@ def preprocess( # noqa: C901 artifact_path = None output_format = "" compile_flags = [] + dump_debug_info = None for spec in compile_spec: if spec.key == "debug_artifact_path": artifact_path = spec.value.decode() @@ -69,6 +71,8 @@ def preprocess( # noqa: C901 output_format = spec.value.decode() if spec.key == "compile_flags": compile_flags.append(spec.value.decode()) + if spec.key == "dump_debug_info": + dump_debug_info = spec.value.decode() # Check that the output format is set correctly in the compile spec if output_format != "tosa": @@ -95,7 +99,11 @@ def preprocess( # noqa: C901 exported_program=edge_program ) - node_visitors = get_node_visitors(edge_program, tosa_spec) + debug_hook = None + if dump_debug_info is not None: + debug_hook = DebugHook() + + node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) input_count = 0 for node in graph_module.graph.nodes: node = cast(Node, node) @@ -126,6 +134,11 @@ def preprocess( # noqa: C901 suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), ) + if debug_hook: + json_output = debug_hook.serialize() + with open(f"{artifact_path}/debug.json", "w") as f: + f.write(json_output) + # Serialize and return the TOSA flatbuffer. binary = bytes(tosa_graph.serialize())