From bf654f8f8bca83362ea44a19fdf8f4cadd8fbcf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Tue, 15 Apr 2025 15:25:05 +0200 Subject: [PATCH] Arm backend: Add test for DeiT Tiny for TOSA BI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test case for the DeiT Tiny model for the TOSA BI profile. At this time the output of the model differs from the reference implementation by a mean absolute error of around 2.5, which is too high. An internal ticket has been raised to resolve this issue. Signed-off-by: Martin Lindström Change-Id: I9f3223068d35fa6f3e485e0fb12ad49bb8b3d534 --- backends/arm/test/models/test_deit_tiny_arm.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/models/test_deit_tiny_arm.py b/backends/arm/test/models/test_deit_tiny_arm.py index b19eb811bb1..f2269e3bed1 100644 --- a/backends/arm/test/models/test_deit_tiny_arm.py +++ b/backends/arm/test/models/test_deit_tiny_arm.py @@ -11,7 +11,10 @@ import torch -from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineMI +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineBI, + TosaPipelineMI, +) from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from torchvision import transforms @@ -42,3 +45,16 @@ def test_deit_tiny_tosa_MI(): qtol=1, ) pipeline.run() + + +def test_deit_tiny_tosa_BI(): + pipeline = TosaPipelineBI[input_t]( + deit_tiny, + model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=3.0, # This needs to go down: MLETORCH-956 + qtol=1, + ) + pipeline.run()