Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _get_input_names(program: ExportedProgram) -> list[str]:


def _get_input_quantization_params(
program: ExportedProgram, input_names: list[str]
program: ExportedProgram,
) -> list[QuantizationParams]:
"""
Get input QuantizationParams in a program, maximum one per input to the program.
Expand All @@ -79,6 +79,7 @@ def _get_input_quantization_params(
"""

quant_params = []
input_names = _get_input_names(program)
num_inputs = len(input_names)
for node in program.graph.nodes:
if (
Expand Down Expand Up @@ -178,16 +179,19 @@ def __init__(

self._has_init_run = False

def init_run(self, exported_program: ExportedProgram, is_quantized: bool):
self.input_names = _get_input_names(exported_program)
def init_run(
self,
exported_program: ExportedProgram,
edge_program: ExportedProgram,
is_quantized: bool,
):
self.input_names = _get_input_names(edge_program)
self.output_node = _get_output_node(exported_program)
self.output_name = self.output_node.name
self.is_quantized = is_quantized

if is_quantized:
self.qp_input = _get_input_quantization_params(
exported_program, self.input_names
)
self.qp_input = _get_input_quantization_params(exported_program)
self.qp_output = _get_output_quantization_params(
exported_program, self.output_node
)
Expand Down Expand Up @@ -407,7 +411,7 @@ def prep_data_for_save(

if is_quantized:
assert (
quant_param.node_name == input_name
quant_param.node_name in input_name
), "These quantization params do not match the input tensor name"
data_np = (
((data_np / np.float32(quant_param.scale)) + quant_param.zp)
Expand Down Expand Up @@ -500,7 +504,10 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
with open(tosa_input_file, "wb") as f:
f.write(tosa_fb)

tosa_schema_file = "./backends/arm/third-party/serialization_lib/schema/tosa.fbs"
arm_backend_path = os.path.realpath(os.path.dirname(__file__) + "/..")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok but unrelated?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's unrelated, I should probably have added several commits to the the PR for clarity

tosa_schema_file = os.path.join(
arm_backend_path, "third-party/serialization_lib/schema/tosa.fbs"
)
assert os.path.exists(
tosa_schema_file
), f"tosa_schema_file: {tosa_schema_file} does not exist"
Expand Down
19 changes: 9 additions & 10 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)

from executorch.backends.arm.test.runner_utils import (
_get_input_names,
_get_input_quantization_params,
_get_output_node,
_get_output_quantization_params,
Expand Down Expand Up @@ -241,15 +240,18 @@ def run_method_and_compare_outputs(
self.runner_util is not None
), "self.tosa_test_util is not initialized, cannot use run_method()"
assert (
self.stages[self.stage_name(tester.Export)] is not None
), "To compare outputs, at least the Export stage needs to be run."
self.stages[self.stage_name(tester.ToEdge)] is not None
), "To compare outputs, at least the ToEdge stage needs to be run."

stage = stage or self.cur
test_stage = self.stages[stage]
is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None
self.runner_util.init_run(
self.stages[self.stage_name(tester.Export)].artifact, is_quantized
)

exported_program = self.stages[self.stage_name(tester.Export)].artifact
edge_program = self.stages[
self.stage_name(tester.ToEdge)
].artifact.exported_program()
self.runner_util.init_run(exported_program, edge_program, is_quantized)

if is_quantized:
reference_stage = self.stages[self.stage_name(tester.Quantize)]
Expand Down Expand Up @@ -395,11 +397,8 @@ def _compare_outputs(
export_stage = self.stages.get(self.stage_name(tester.Export), None)
quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
if export_stage is not None and quantize_stage is not None:
input_names = _get_input_names(export_stage.artifact)
output_node = _get_output_node(export_stage.artifact)
qp_input = _get_input_quantization_params(
export_stage.artifact, input_names
)
qp_input = _get_input_quantization_params(export_stage.artifact)
qp_output = _get_output_quantization_params(
export_stage.artifact, output_node
)
Expand Down
Loading