diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 6e8b9b25ede..8e7edecca7d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -66,7 +66,7 @@ def _get_input_names(program: ExportedProgram) -> list[str]: def _get_input_quantization_params( - program: ExportedProgram, input_names: list[str] + program: ExportedProgram, ) -> list[QuantizationParams]: """ Get input QuantizationParams in a program, maximum one per input to the program. @@ -79,6 +79,7 @@ def _get_input_quantization_params( """ quant_params = [] + input_names = _get_input_names(program) num_inputs = len(input_names) for node in program.graph.nodes: if ( @@ -178,16 +179,19 @@ def __init__( self._has_init_run = False - def init_run(self, exported_program: ExportedProgram, is_quantized: bool): - self.input_names = _get_input_names(exported_program) + def init_run( + self, + exported_program: ExportedProgram, + edge_program: ExportedProgram, + is_quantized: bool, + ): + self.input_names = _get_input_names(edge_program) self.output_node = _get_output_node(exported_program) self.output_name = self.output_node.name self.is_quantized = is_quantized if is_quantized: - self.qp_input = _get_input_quantization_params( - exported_program, self.input_names - ) + self.qp_input = _get_input_quantization_params(exported_program) self.qp_output = _get_output_quantization_params( exported_program, self.output_node ) @@ -407,7 +411,7 @@ def prep_data_for_save( if is_quantized: assert ( - quant_param.node_name == input_name + quant_param.node_name in input_name ), "These quantization params do not match the input tensor name" data_np = ( ((data_np / np.float32(quant_param.scale)) + quant_param.zp) @@ -500,7 +504,10 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: with open(tosa_input_file, "wb") as f: f.write(tosa_fb) - tosa_schema_file = "./backends/arm/third-party/serialization_lib/schema/tosa.fbs" + arm_backend_path = os.path.realpath(os.path.dirname(__file__) + "/..") + tosa_schema_file = os.path.join( + arm_backend_path, "third-party/serialization_lib/schema/tosa.fbs" + ) assert os.path.exists( tosa_schema_file ), f"tosa_schema_file: {tosa_schema_file} does not exist" diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 2fe8c07e7d1..994313a1ff0 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -23,7 +23,6 @@ ) from executorch.backends.arm.test.runner_utils import ( - _get_input_names, _get_input_quantization_params, _get_output_node, _get_output_quantization_params, @@ -241,15 +240,18 @@ def run_method_and_compare_outputs( self.runner_util is not None ), "self.tosa_test_util is not initialized, cannot use run_method()" assert ( - self.stages[self.stage_name(tester.Export)] is not None - ), "To compare outputs, at least the Export stage needs to be run." + self.stages[self.stage_name(tester.ToEdge)] is not None + ), "To compare outputs, at least the ToEdge stage needs to be run." stage = stage or self.cur test_stage = self.stages[stage] is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None - self.runner_util.init_run( - self.stages[self.stage_name(tester.Export)].artifact, is_quantized - ) + + exported_program = self.stages[self.stage_name(tester.Export)].artifact + edge_program = self.stages[ + self.stage_name(tester.ToEdge) + ].artifact.exported_program() + self.runner_util.init_run(exported_program, edge_program, is_quantized) if is_quantized: reference_stage = self.stages[self.stage_name(tester.Quantize)] @@ -395,11 +397,8 @@ def _compare_outputs( export_stage = self.stages.get(self.stage_name(tester.Export), None) quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None) if export_stage is not None and quantize_stage is not None: - input_names = _get_input_names(export_stage.artifact) output_node = _get_output_node(export_stage.artifact) - qp_input = _get_input_quantization_params( - export_stage.artifact, input_names - ) + qp_input = _get_input_quantization_params(export_stage.artifact) qp_output = _get_output_quantization_params( export_stage.artifact, output_node )