Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def get_registered_tosa_support_checks(
) -> list[Type[SupportedTOSAOperatorCheck]]:

if tosa_spec not in _tosa_spec_support:
raise RuntimeError
raise RuntimeError(
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
)

return _tosa_spec_support[tosa_spec]

Expand Down
6 changes: 3 additions & 3 deletions backends/arm/operators/op_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def define_node(
shape = input_node.shape
dim = dim.number
if end.number < 0:
end = end.number % shape[dim]
end_index = end.number % shape[dim]
else:
end = min(end.number, shape[dim])
size = end - start.number
end_index = min(end.number, shape[dim])
size = end_index - start.number
assert size > 0
assert size <= shape[dim]

Expand Down
75 changes: 55 additions & 20 deletions backends/arm/process_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
from executorch.backends.arm.tosa_utils import (
get_node_debug_info,
getNodeArgs,
tosa_shape,
)
from torch.export.exported_program import ExportedProgram


Expand All @@ -28,8 +32,13 @@ def process_call_function(
inputs = getNodeArgs(node)

# Convert output (this node itself)
output = TosaArg(node)

try:
output = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing call_function:\n{get_node_debug_info(node)}"
"Is the original torch function supported?"
) from e
tosa_graph.currRegion.currBasicBlock.addTensor(
output.name, tosa_shape(output.shape, output.dim_order), output.dtype
)
Expand Down Expand Up @@ -61,15 +70,21 @@ def process_inputs(
f"Arm backend only supports contiguous memory format for inputs. "
f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
)
inputs = [TosaArg(node)]
input_shape = inputs[0].shape
input_dim_order = inputs[0].dim_order
try:
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing input placeholder:\n{get_node_debug_info(node)}"
"Is the original torch function supported?"
) from e
input_shape = tosa_arg.shape
input_dim_order = tosa_arg.dim_order
tensor = ts.TosaSerializerTensor(
inputs[0].name,
tosa_arg.name,
tosa_shape(input_shape, input_dim_order),
inputs[0].dtype,
tosa_arg.dtype,
data=None,
placeholderFilename=inputs[0].name + ".npy",
placeholderFilename=tosa_arg.name + ".npy",
)
tosa_graph.addInputTensor(tensor)

Expand All @@ -81,20 +96,26 @@ def process_inputs_to_parameters(
tosa_spec: TosaSpecification,
):
"""Serialize bias and non-quantized weights"""
inputs = [TosaArg(node)]
parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
try:
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}"
"Is the original torch function supported?"
) from e
parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
parameter_data = edge_program.state_dict[parameter_name]

assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
parameter_values = parameter_data.detach().numpy()

if inputs[0].dtype == torch.float32:
if tosa_arg.dtype == torch.float32:
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)

tosa_graph.addConst(
parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name
)


Expand All @@ -104,7 +125,13 @@ def process_inputs_to_buffers(
edge_program: ExportedProgram,
):
"""Serialize quantized weights"""
inputs = [TosaArg(node)]
try:
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}"
"Is the original torch function supported?"
) from e
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
buffer_data = edge_program.state_dict[buffer_name]

Expand All @@ -114,10 +141,10 @@ def process_inputs_to_buffers(
# TODO: fragile code for temporary fix
# the mean and var tensors are also stored here but they have shape (1, )
# we only transpose weights here
buffer_values = np.transpose(buffer_values, inputs[0].dim_order)
buffer_values = np.transpose(buffer_values, tosa_arg.dim_order)

tosa_graph.addConst(
buffer_values.shape, inputs[0].dtype, buffer_values, name=node.name
buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name
)


Expand All @@ -126,14 +153,22 @@ def process_inputs_to_lifted_tensor_constants(
tosa_graph: ts.TosaSerializer,
edge_program: ExportedProgram,
):
arg = TosaArg(node)
try:
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}"
"Is the original torch function supported?"
) from e
tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
arg.name
tosa_arg.name
]
tensor = edge_program.tensor_constants[tensor_name]
tensor_data = tensor.detach().numpy()

tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)
tosa_graph.addConst(
tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
)


def process_placeholder(
Expand Down
25 changes: 13 additions & 12 deletions backends/arm/tosa_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@


def map_dtype(data_type):
assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}"
assert data_type in DTYPE_MAP, f"Unknown type: {data_type}"
if data_type in UNSUPPORTED_DTYPES:
raise ValueError(f"Unsupported type: {data_type}")
if data_type not in DTYPE_MAP:
raise ValueError(f"Unknown type: {data_type}")
return DTYPE_MAP[data_type]


Expand All @@ -58,7 +60,10 @@ def extract_tensor_meta(meta):
# TODO: should use first concrete representation
val = val[0]

assert torch._subclasses.fake_tensor.FakeTensor == type(val)
if not isinstance(val, torch._subclasses.fake_tensor.FakeTensor):
raise ValueError(
f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}"
)
dtype = map_dtype(val.dtype)
shape = tuple(val.size())

Expand All @@ -71,19 +76,18 @@ def extract_tensor_meta(meta):

# Class to capture arguments and turn into tensor references for TOSA OPs
class TosaArg:
def __process_node(self, argument):
assert isinstance(argument, torch.fx.node.Node)
def __process_node(self, argument: torch.fx.Node):
self.name = argument.name
self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)

def __process_list(self, argument):
self.special = list(argument)

def __process_number(self, argument):
def __process_number(self, argument: float | int):
self.number = argument

def __init__(self, argument) -> None:
self.name = None
self.name = None # type: ignore[assignment]
self.dtype = None
self.shape = None
self.dim_order = None
Expand All @@ -92,16 +96,13 @@ def __init__(self, argument) -> None:
if argument is None:
return

if isinstance(argument, torch.fx.node.Node):
if isinstance(argument, torch.fx.Node):
self.__process_node(argument)
return
if isinstance(argument, list):
self.__process_list(argument)
return
if isinstance(argument, int):
self.__process_number(argument)
return
if isinstance(argument, float):
if isinstance(argument, (int, float)):
self.__process_number(argument)
return

Expand Down
39 changes: 26 additions & 13 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,28 @@
logger.setLevel(logging.INFO)


def dbg_node(node):
def dbg_node(node: torch.fx.Node):
# Debug output of node information
logger.info("OP")
logger.info(f" op is {node.op}")
logger.info(f" name is {node.name}")
logger.info(f" node target is {node.target}")
logger.info(f" node args is {node.args}")
logger.info(f" node kwargs is {node.kwargs}")
logger.info(" node.meta = ")
logger.info(get_node_debug_info(node))


def get_node_debug_info(node: torch.fx.Node) -> str:
output = (
"-- NODE DEBUG INFO --\n"
f" Op is {node.op}\n"
f" Name is {node.name}\n"
f" Node target is {node.target}\n"
f" Node args is {node.args}\n"
f" Node kwargs is {node.kwargs}\n"
f" Node users is {node.users}\n"
" Node.meta = \n"
)
for k, v in node.meta.items():
logger.info(f" '{k}' = {v}")
output += f" '{k}' = {v}\n"
if isinstance(v, list):
for i in v:
logger.info(f" {i} ")
output += f" {i}\n"
return output


# Output TOSA flatbuffer and test harness file
Expand All @@ -65,14 +73,19 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):

def dbg_fail(node, tosa_graph, path):
dbg_tosa_dump(tosa_graph, path)
logger.warn("Internal error due to poorly handled node:")
logger.warning("Internal error due to poorly handled node:")
dbg_node(node)
logger.warn(f"Debug output captured in '{path}'.")
logger.warning(f"Debug output captured in '{path}'.")
raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")


def getNodeArgs(node: Node) -> list[TosaArg]:
return [TosaArg(arg) for arg in node.args]
try:
return [TosaArg(arg) for arg in node.args]
except ValueError as e:
raise ValueError(
f"Failed processing args to op:\n{get_node_debug_info(node)}"
) from e


def get_output_node(node: Node) -> Node:
Expand Down
Loading