Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 65 additions & 10 deletions backends/arm/util/arm_model_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,69 @@ def evaluate(self) -> dict[str, Any]:
return output


class ResNet18Evaluator(GenericModelEvaluator):
REQUIRES_CONFIG = True

def __init__(
self,
model_name: str,
fp32_model: Module,
int8_model: Module,
example_input: Tuple[torch.Tensor],
tosa_output_path: str | None,
batch_size: int,
validation_dataset_path: str,
) -> None:
super().__init__(
model_name, fp32_model, int8_model, example_input, tosa_output_path
)
self.__batch_size = batch_size
self.__validation_set_path = validation_dataset_path

@staticmethod
def __load_dataset(directory: str) -> datasets.ImageFolder:
return _load_imagenet_folder(directory)

@staticmethod
def get_calibrator(training_dataset_path: str) -> DataLoader:
dataset = ResNet18Evaluator.__load_dataset(training_dataset_path)
return _build_calibration_loader(dataset, 1000)

@classmethod
def from_config(
cls,
model_name: str,
fp32_model: Module,
int8_model: Module,
example_input: Tuple[torch.Tensor],
tosa_output_path: str | None,
config: dict[str, Any],
) -> "ResNet18Evaluator":
return cls(
model_name,
fp32_model,
int8_model,
example_input,
tosa_output_path,
batch_size=config["batch_size"],
validation_dataset_path=config["validation_dataset_path"],
)

def evaluate(self) -> dict[str, Any]:
dataset = ResNet18Evaluator.__load_dataset(self.__validation_set_path)
top1, top5 = GenericModelEvaluator.evaluate_topk(
self.int8_model, dataset, self.__batch_size, topk=5
)
output = super().evaluate()
output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
return output


evaluators: dict[str, type[GenericModelEvaluator]] = {
"generic": GenericModelEvaluator,
"mv2": MobileNetV2Evaluator,
"deit_tiny": DeiTTinyEvaluator,
"resnet18": ResNet18Evaluator,
}


Expand All @@ -394,16 +453,12 @@ def evaluator_calibration_data(
with config_path.open() as f:
config = json.load(f)

if evaluator is MobileNetV2Evaluator:
return evaluator.get_calibrator(
training_dataset_path=config["training_dataset_path"]
)
if evaluator is DeiTTinyEvaluator:
return evaluator.get_calibrator(
training_dataset_path=config["training_dataset_path"]
)
else:
raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
# All current evaluators exposing calibration implement a uniform
# static method signature: get_calibrator(training_dataset_path: str)
# so we can call it generically without enumerating classes.
return evaluator.get_calibrator(
training_dataset_path=config["training_dataset_path"]
)


def evaluate_model(
Expand Down
2 changes: 1 addition & 1 deletion examples/arm/aot_arm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def get_args():
required=False,
nargs="?",
const="generic",
choices=["generic", "mv2", "deit_tiny"],
choices=["generic", "mv2", "deit_tiny", "resnet18"],
help="Flag for running evaluation of the model.",
)
parser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions examples/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Model(str, Enum):
Qwen25 = "qwen2_5_1_5b"
Phi4Mini = "phi_4_mini"
SmolLM2 = "smollm2"
DeiTTiny = "deit_tiny"

def __str__(self) -> str:
return self.value
Expand Down Expand Up @@ -87,6 +88,7 @@ def __str__(self) -> str:
str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"),
str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"),
str(Model.SmolLM2): ("smollm2", "SmolLM2Model"),
str(Model.DeiTTiny): ("deit_tiny", "DeiTTinyModel"),
}

__all__ = [
Expand Down
8 changes: 8 additions & 0 deletions examples/models/deit_tiny/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import DeiTTinyModel

__all__ = ["DeiTTinyModel"]
42 changes: 42 additions & 0 deletions examples/models/deit_tiny/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from torchvision import transforms

try:
import timm # type: ignore
except ImportError as e: # pragma: no cover
raise RuntimeError(
"timm package is required for builtin 'deit_tiny'. Install timm."
) from e

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

from ..model_base import EagerModelBase


class DeiTTinyModel(EagerModelBase):

def __init__(self): # type: ignore[override]
pass

def get_eager_model(self) -> torch.nn.Module: # type: ignore[override]
logging.info("Loading timm deit_tiny_patch16_224 model")
model = timm.models.deit.deit_tiny_patch16_224(pretrained=False)
model.eval()
logging.info("Loaded timm deit_tiny_patch16_224 model")
return model

def get_example_inputs(self): # type: ignore[override]
normalize = transforms.Normalize(
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
)
return (normalize(torch.rand((1, 3, 224, 224))),)


__all__ = ["DeiTTinyModel"]
Loading