Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# backends. Converts via TOSA as an intermediate form supported by AoT and
# JIT compiler flows.
#
from enum import Enum
from typing import List, Optional

from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found]
Expand All @@ -22,12 +23,16 @@


class ArmCompileSpecBuilder:
class DebugMode(Enum):
JSON = 1

def __init__(self):
self.compile_spec: List[CompileSpec] = []
self.compiler_flags = []
self.output_format = None
self.path_for_intermediates = None
self.tosa_spec = None
self.tosa_debug_mode = None

def vgf_compile_spec(
self,
Expand Down Expand Up @@ -163,6 +168,13 @@ def dump_intermediate_artifacts_to(
self.path_for_intermediates = output_path
return self

def dump_debug_info(self, debug_mode: DebugMode) -> "ArmCompileSpecBuilder":
"""
Dump debugging information into the intermediates path
"""
self.tosa_debug_mode = debug_mode.name
return self

def build(self) -> List[CompileSpec]:
"""
Generate a list of compile spec objects from the builder
Expand All @@ -188,6 +200,16 @@ def build(self) -> List[CompileSpec]:
CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
)

if self.tosa_debug_mode is not None:
if not self.path_for_intermediates:
raise ValueError(
"dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()"
)

self.compile_spec.append(
CompileSpec("dump_debug_info", self.tosa_debug_mode.encode())
)

return self.compile_spec


Expand Down
34 changes: 32 additions & 2 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

# pyre-unsafe

from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import torch

from executorch.backends.arm.debug.schema import DebugHook
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.export import ExportedProgram
Expand All @@ -29,9 +30,38 @@ class NodeVisitor:
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
def __init__(
self,
exported_program: ExportedProgram,
tosa_spec: TosaSpecification,
debug_hook: Optional[DebugHook] = None,
):
self._exported_program = exported_program
self.tosa_spec = tosa_spec
self.debug_hook = debug_hook

def _serialize_operator(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we extend tosa_graph and not overload node visitor with this method?
The tosa_graph_torch (let's say) constructor takes debug_hook or other args and prepares a "graph", through addOperator like methods, which can be observed independently just like an FX graph.

Just an idea, don't want to block this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion Digant. I like the idea, I think it would look a bit cleaner

I can look into this, there will be one/two more commits to finalise this feature. So I think there is scope to incorporate this.

self,
node: torch.fx.Node,
tosa_graph: Any,
tosa_op: Any,
inputs: List[str],
outputs: List[str],
attributes: Optional[Any] = None,
) -> None:
tosa_graph.addOperator(
tosa_op,
inputs=inputs,
outputs=outputs,
attributes=attributes,
)

if self.debug_hook:
self.debug_hook.add(
node,
tosa_op=outputs[0],
tosa_op_id=tosa_op,
)

def define_node(
self,
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def define_node(
)

# MI lowering
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().ABS,
[inputs[0].name],
[output.name],
Expand Down
8 changes: 6 additions & 2 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def define_node(
input1, input2 = rescaled_inputs

# Do the INT32 Add
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().ADD,
[input1.name, input2.name],
[add_output.name],
Expand Down Expand Up @@ -127,7 +129,9 @@ def define_node(
input1, input2 = inputs

# FP lowering
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().ADD,
[input1.name, input2.name],
[output.name],
Expand Down
9 changes: 7 additions & 2 deletions backends/arm/operators/op_amax.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def define_node(

attr = ts.TosaSerializerAttribute()
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1)
tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_MAX,
[input.name],
[output.name],
attr,
)
9 changes: 7 additions & 2 deletions backends/arm/operators/op_amin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def define_node(

attr = ts.TosaSerializerAttribute()
attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1)
tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_MIN,
[input.name],
[output.name],
attr,
)
9 changes: 7 additions & 2 deletions backends/arm/operators/op_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def define_node(
attr = ts.TosaSerializerAttribute()
attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))

tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_ANY,
[inputs[0].name],
[output.name],
attr,
)
4 changes: 3 additions & 1 deletion backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def _build_generic_avgpool2d(
shape=[1], dtype=output.dtype, vals=[output_zp]
)

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().AVG_POOL2D,
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
[output.name],
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def define_node(
tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=f"{node.name}_B_ZP")

# Add the MATMUL to the TOSA graph.
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().MATMUL,
[
inputs[0].name,
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def define_node(
attr = ts.TosaSerializerAttribute()
attr.ConcatAttribute(dim)

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().CONCAT,
[tensor.name for tensor in tensors],
[output.name],
Expand Down
18 changes: 14 additions & 4 deletions backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,13 @@ def define_node(
nan_mode=1,
)

tosa_graph.addOperator(
ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().CLAMP,
[inputs[0].name],
[output.name],
attr,
)


Expand Down Expand Up @@ -138,6 +143,11 @@ def define_node(
nan_mode=1,
)

tosa_graph.addOperator(
ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().CLAMP,
[inputs[0].name],
[output.name],
attr,
)
4 changes: 3 additions & 1 deletion backends/arm/operators/op_constant_pad_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def define_node(
shape=[1], dtype=pad_const_dtype, vals=[pad_const_val]
)

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().PAD,
[inputs[0].name, padding.name, pad_const.name],
[output.name],
Expand Down
8 changes: 6 additions & 2 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def define_node(

reshape_attr = ts.TosaSerializerAttribute()
reshape_attr.ReshapeAttribute()
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().RESHAPE,
[weight.name, shape.name],
[weight_reshaped.name],
Expand Down Expand Up @@ -188,7 +190,9 @@ def define_node(
acc_type=acc_type,
)

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
tosa_op,
[
input.name,
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_cos.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,6 @@ def define_node(
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
)

tosa_graph.addOperator(ts.TosaOp.Op().COS, [inputs[0].name], [output.name])
self._serialize_operator(
node, tosa_graph, ts.TosaOp.Op().COS, [inputs[0].name], [output.name]
)
6 changes: 4 additions & 2 deletions backends/arm/operators/op_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ def define_node(
input_nodes = rescaled_inputs

# Do the equal comparison
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().EQUAL,
[input_nodes[0].name, input_nodes[1].name],
output.name,
[output.name],
None,
)
4 changes: 3 additions & 1 deletion backends/arm/operators/op_erf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,6 @@ def define_node(
)

# MI lowering
tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name])
self._serialize_operator(
node, tosa_graph, ts.TosaOp.Op().ERF, [inputs[0].name], [output.name]
)
4 changes: 3 additions & 1 deletion backends/arm/operators/op_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,6 @@ def define_node(
output.tosa_spec,
)

tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name])
self._serialize_operator(
node, tosa_graph, ts.TosaOp.Op().EXP, [inputs[0].name], [output.name]
)
4 changes: 3 additions & 1 deletion backends/arm/operators/op_ge.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def define_node(
# Update IO
input_nodes = rescaled_inputs

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().GREATER_EQUAL,
[input_nodes[0].name, input_nodes[1].name],
[output.name],
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def define_node(
# Update IO
input_nodes = rescaled_inputs

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().GREATER,
[input_nodes[0].name, input_nodes[1].name],
[output.name],
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_index_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def define_node(
tosa_graph, indices.name, indices_new_shape, indices_reshaped.name
)

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().GATHER,
[weights_reshaped.name, indices_reshaped.name],
[output_name],
Expand Down
12 changes: 9 additions & 3 deletions backends/arm/operators/op_index_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def define_node(
data = np.full(index_shape, int(values_strides[i] / C))
mul_const = tosa_graph.addConst(index_shape, index_dtype, data)
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift")
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().MUL,
[index_name, mul_const.name, f"{node.name}_{i}_shift"],
[stride_shifted_indices.name],
Expand All @@ -194,7 +196,9 @@ def define_node(
reshaped_idxs.shape,
reshaped_idxs.dtype,
)
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().ADD,
[gather_index_name, reshaped_idxs.name],
[add_idxs.name],
Expand All @@ -217,7 +221,9 @@ def define_node(
gather_out_shape,
output.dtype,
)
tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().GATHER,
[reshaped_input.name, gather_index_name],
[gather_out.name],
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_le.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def define_node(
# Update IO
input_nodes = rescaled_inputs

tosa_graph.addOperator(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().GREATER_EQUAL,
[input_nodes[1].name, input_nodes[0].name],
[output.name],
Expand Down
Loading
Loading