Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions backends/arm/scripts/parse_test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

# Add all targets and TOSA profiles we support here.
TARGETS = {"tosa_BI", "tosa_MI", "u55_BI", "u85_BI"}
TARGETS = ["tosa_MI", "tosa_BI", "u55_BI", "u85_BI"]


def get_edge_ops():
def get_op_name_map():
"""
Returns a set with edge_ops with names on the form to be used in unittests:
Returns a mapping from names on the form to be used in unittests to edge op:
1. Names are in lowercase.
2. Overload is ignored if it is 'default', otherwise its appended with an underscore.
2. Overload is ignored if 'default', otherwise it's appended with an underscore.
3. Overly verbose name are shortened by removing certain prefixes/suffixes.

Examples:
abs.default -> abs
split_copy.Tensor -> split_tensor
"""
edge_ops = set()
map = {}
for edge_name in ALL_EDGE_OPS:
op, overload = edge_name.split(".")

Expand All @@ -44,21 +44,24 @@ def get_edge_ops():
overload = overload.lower()

if overload == "default":
edge_ops.add(op)
map[op] = edge_name
else:
edge_ops.add(f"{op}_{overload}")
map[f"{op}_{overload}"] = edge_name

return edge_ops
return map


def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]:
def parse_test_name(
test_name: str, op_name_map: dict[str, str]
) -> tuple[str, str, bool]:
"""
Parses a test name on the form
test_OP_TARGET_<not_delegated>_<any_other_info>
where OP must match a string in edge_ops and TARGET must match one string in TARGETS.
The "not_delegated" suffix indicates that the test tests that the op is not delegated.
where OP must match a key in op_name_map and TARGET one string in TARGETS. The
"not_delegated" suffix indicates that the test tests that the op is not delegated.

Examples of valid names: "test_mm_u55_BI_not_delegated" or "test_add_scalar_tosa_MI_two_inputs".
Examples of valid names: "test_mm_u55_BI_not_delegated" and
"test_add_scalar_tosa_MI_two_inputs".

Returns a tuple (OP, TARGET, IS_DELEGATED) if valid.
"""
Expand All @@ -82,7 +85,7 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]

assert target != "None", f"{test_name} does not contain one of {TARGETS}"
assert (
op in edge_ops
op in op_name_map.keys()
), f"Parsed unvalid OP from {test_name}, {op} does not exist in edge.yaml or CUSTOM_EDGE_OPS"

return op, target, is_delegated
Expand All @@ -94,13 +97,13 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]

sys.tracebacklimit = 0 # Do not print stack trace

edge_ops = get_edge_ops()
op_name_map = get_op_name_map()
exit_code = 0

for test_name in sys.argv[1:]:
try:
assert test_name[:5] == "test_", f"Unexpected input: {test_name}"
parse_test_name(test_name, edge_ops)
parse_test_name(test_name, op_name_map)
except AssertionError as e:
print(e)
exit_code = 1
Expand Down
Loading