diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py index ac2c9cfe065..5d33812df97 100644 --- a/backends/arm/util/arm_model_evaluator.py +++ b/backends/arm/util/arm_model_evaluator.py @@ -374,10 +374,69 @@ def evaluate(self) -> dict[str, Any]: return output +class ResNet18Evaluator(GenericModelEvaluator): + REQUIRES_CONFIG = True + + def __init__( + self, + model_name: str, + fp32_model: Module, + int8_model: Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: str | None, + batch_size: int, + validation_dataset_path: str, + ) -> None: + super().__init__( + model_name, fp32_model, int8_model, example_input, tosa_output_path + ) + self.__batch_size = batch_size + self.__validation_set_path = validation_dataset_path + + @staticmethod + def __load_dataset(directory: str) -> datasets.ImageFolder: + return _load_imagenet_folder(directory) + + @staticmethod + def get_calibrator(training_dataset_path: str) -> DataLoader: + dataset = ResNet18Evaluator.__load_dataset(training_dataset_path) + return _build_calibration_loader(dataset, 1000) + + @classmethod + def from_config( + cls, + model_name: str, + fp32_model: Module, + int8_model: Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: str | None, + config: dict[str, Any], + ) -> "ResNet18Evaluator": + return cls( + model_name, + fp32_model, + int8_model, + example_input, + tosa_output_path, + batch_size=config["batch_size"], + validation_dataset_path=config["validation_dataset_path"], + ) + + def evaluate(self) -> dict[str, Any]: + dataset = ResNet18Evaluator.__load_dataset(self.__validation_set_path) + top1, top5 = GenericModelEvaluator.evaluate_topk( + self.int8_model, dataset, self.__batch_size, topk=5 + ) + output = super().evaluate() + output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5} + return output + + evaluators: dict[str, type[GenericModelEvaluator]] = { "generic": GenericModelEvaluator, "mv2": MobileNetV2Evaluator, "deit_tiny": DeiTTinyEvaluator, + "resnet18": ResNet18Evaluator, } @@ -394,16 +453,12 @@ def evaluator_calibration_data( with config_path.open() as f: config = json.load(f) - if evaluator is MobileNetV2Evaluator: - return evaluator.get_calibrator( - training_dataset_path=config["training_dataset_path"] - ) - if evaluator is DeiTTinyEvaluator: - return evaluator.get_calibrator( - training_dataset_path=config["training_dataset_path"] - ) - else: - raise RuntimeError(f"Unknown evaluator: {evaluator_name}") + # All current evaluators exposing calibration implement a uniform + # static method signature: get_calibrator(training_dataset_path: str) + # so we can call it generically without enumerating classes. + return evaluator.get_calibrator( + training_dataset_path=config["training_dataset_path"] + ) def evaluate_model( diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 1348542de07..9c35e23d5dd 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -492,7 +492,7 @@ def get_args(): required=False, nargs="?", const="generic", - choices=["generic", "mv2", "deit_tiny"], + choices=["generic", "mv2", "deit_tiny", "resnet18"], help="Flag for running evaluation of the model.", ) parser.add_argument( diff --git a/examples/models/__init__.py b/examples/models/__init__.py index 5f9791843aa..45abfd8f89d 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -39,6 +39,7 @@ class Model(str, Enum): Qwen25 = "qwen2_5_1_5b" Phi4Mini = "phi_4_mini" SmolLM2 = "smollm2" + DeiTTiny = "deit_tiny" def __str__(self) -> str: return self.value @@ -87,6 +88,7 @@ def __str__(self) -> str: str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"), str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"), str(Model.SmolLM2): ("smollm2", "SmolLM2Model"), + str(Model.DeiTTiny): ("deit_tiny", "DeiTTinyModel"), } __all__ = [ diff --git a/examples/models/deit_tiny/__init__.py b/examples/models/deit_tiny/__init__.py new file mode 100644 index 00000000000..d43d533e1ab --- /dev/null +++ b/examples/models/deit_tiny/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .model import DeiTTinyModel + +__all__ = ["DeiTTinyModel"] diff --git a/examples/models/deit_tiny/model.py b/examples/models/deit_tiny/model.py new file mode 100644 index 00000000000..e92167bfbb4 --- /dev/null +++ b/examples/models/deit_tiny/model.py @@ -0,0 +1,42 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from torchvision import transforms + +try: + import timm # type: ignore +except ImportError as e: # pragma: no cover + raise RuntimeError( + "timm package is required for builtin 'deit_tiny'. Install timm." + ) from e + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD + +from ..model_base import EagerModelBase + + +class DeiTTinyModel(EagerModelBase): + + def __init__(self): # type: ignore[override] + pass + + def get_eager_model(self) -> torch.nn.Module: # type: ignore[override] + logging.info("Loading timm deit_tiny_patch16_224 model") + model = timm.models.deit.deit_tiny_patch16_224(pretrained=False) + model.eval() + logging.info("Loaded timm deit_tiny_patch16_224 model") + return model + + def get_example_inputs(self): # type: ignore[override] + normalize = transforms.Normalize( + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD + ) + return (normalize(torch.rand((1, 3, 224, 224))),) + + +__all__ = ["DeiTTinyModel"]