diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index 4eaf1c205cc..4b43b6c9e50 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import itertools -from typing import Tuple +from typing import Any, Callable, Tuple import torch from executorch.backends.arm.quantizer import is_annotated @@ -18,20 +18,25 @@ class SingleOpModel(torch.nn.Module): - def __init__(self, op, example_input, **op_kwargs) -> None: + def __init__( + self, + op: Callable[..., torch.Tensor], + example_input: Tuple[Any, ...], + **op_kwargs: Any, + ) -> None: super().__init__() - self.op = op - self._example_input = example_input - self.op_kwargs = op_kwargs + self.op: Callable[..., torch.Tensor] = op + self._example_input: Tuple[Any, ...] = example_input + self.op_kwargs: dict[str, Any] = dict(op_kwargs) - def forward(self, x): + def forward(self, x: Any) -> torch.Tensor: return self.op(x, **self.op_kwargs) - def example_inputs(self): + def example_inputs(self) -> Tuple[Any, ...]: return self._example_input -def check_annotation(model): +def check_annotation(model: SingleOpModel) -> None: pipeline = TosaPipelineINT[input_t1](model, model.example_inputs(), [], []) pipeline.pop_stage("check_count.exir") pipeline.pop_stage("run_method_and_compare_outputs")