From aa98905f8bd1d60b95111c48ac77e3831915722c Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Tue, 1 Oct 2024 10:05:20 +0100 Subject: [PATCH] Minor fixes around the Arm testing framework Signed-off-by: Benjamin Klimczak Change-Id: I9c01eff1b4a06e327bfdef82f63c8be5d089c705 --- backends/arm/test/runner_utils.py | 7 +-- backends/arm/test/tester/arm_tester.py | 6 ++- backends/arm/tosa_mapping.py | 60 +++++++++++++------------- 3 files changed, 39 insertions(+), 34 deletions(-) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 0a0143e14c6..5ca571f26da 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -457,9 +457,10 @@ def prep_data_for_save( data_np = np.array(data.detach(), order="C").astype(np.float32) if is_quantized: - assert ( - quant_param.node_name in input_name - ), "These quantization params do not match the input tensor name" + assert quant_param.node_name in input_name, ( + f"The quantization params name '{quant_param.node_name}' does not " + f"match the input tensor name '{input_name}'." + ) data_np = ( ((data_np / np.float32(quant_param.scale)) + quant_param.zp) .round() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index eb52f4b2070..053ddc3a8ef 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -150,6 +150,7 @@ def __init__( model: torch.nn.Module, example_inputs: Tuple[torch.Tensor], compile_spec: List[CompileSpec] = None, + tosa_ref_model_path: str | None = None, ): """ Args: @@ -160,7 +161,10 @@ def __init__( # Initiate runner_util intermediate_path = get_intermediate_path(compile_spec) - self.runner_util = RunnerUtil(intermediate_path=intermediate_path) + self.runner_util = RunnerUtil( + intermediate_path=intermediate_path, + tosa_ref_model_path=tosa_ref_model_path, + ) self.compile_spec = compile_spec super().__init__(model, example_inputs) diff --git a/backends/arm/tosa_mapping.py b/backends/arm/tosa_mapping.py index 0baf3e2ec1b..ec57bd5ce22 100644 --- a/backends/arm/tosa_mapping.py +++ b/backends/arm/tosa_mapping.py @@ -15,37 +15,37 @@ import torch +UNSUPPORTED_DTYPES = ( + torch.float64, + torch.double, + torch.complex64, + torch.cfloat, + torch.complex128, + torch.cdouble, + torch.uint8, + torch.int64, + torch.long, +) + +DTYPE_MAP = { + torch.float32: ts.DType.FP32, + torch.float: ts.DType.FP32, + torch.float16: ts.DType.FP16, + torch.half: ts.DType.FP16, + torch.bfloat16: ts.DType.BF16, + torch.int8: ts.DType.INT8, + torch.int16: ts.DType.INT16, + torch.short: ts.DType.INT16, + torch.int32: ts.DType.INT32, + torch.int: ts.DType.INT32, + torch.bool: ts.DType.BOOL, +} + + def map_dtype(data_type): - unsupported = ( - torch.float64, - torch.double, - torch.complex64, - torch.cfloat, - torch.complex128, - torch.cdouble, - torch.uint8, - torch.int64, - torch.long, - ) - - dmap = { - torch.float32: ts.DType.FP32, - torch.float: ts.DType.FP32, - torch.float16: ts.DType.FP16, - torch.half: ts.DType.FP16, - torch.bfloat16: ts.DType.BF16, - torch.int8: ts.DType.INT8, - torch.int16: ts.DType.INT16, - torch.short: ts.DType.INT16, - torch.int32: ts.DType.INT32, - torch.int: ts.DType.INT32, - torch.bool: ts.DType.BOOL, - } - - assert unsupported.count(data_type) == 0, "Unsupported type" - rtype = dmap.get(data_type) - assert rtype is not None, "Unknown type" - return rtype + assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}" + assert data_type in DTYPE_MAP, f"Unknown type: {data_type}" + return DTYPE_MAP[data_type] # Returns the shape and type of a node