diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 3a9818929b9..817a0e93033 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -29,6 +29,8 @@ class arm_test_options(Enum): corstone300 = auto() dump_path = auto() date_format = auto() + model_explorer_host = auto() + model_explorer_port = auto() _test_options: dict[arm_test_options, Any] = {} @@ -41,6 +43,18 @@ def pytest_addoption(parser): parser.addoption("--arm_run_corstone300", action="store_true") parser.addoption("--default_dump_path", default=None) parser.addoption("--date_format", default="%d-%b-%H:%M:%S") + parser.addoption( + "--model_explorer_host", + action="store", + default=None, + help="If set, tries to connect to existing model-explorer server rather than starting a new one.", + ) + parser.addoption( + "--model_explorer_port", + action="store", + default=None, + help="Set the port of the model explorer server. If not set, tries ports between 8080 and 8099.", + ) def pytest_configure(config): @@ -62,7 +76,19 @@ def pytest_configure(config): raise RuntimeError( f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory." ) + if config.option.model_explorer_port: + if not str.isdecimal(config.option.model_explorer_port): + raise RuntimeError( + f"--model_explorer_port needs to be an integer, got '{config.option.model_explorer_port}'." + ) + else: + _test_options[arm_test_options.model_explorer_port] = int( + config.option.model_explorer_port + ) _test_options[arm_test_options.date_format] = config.option.date_format + _test_options[arm_test_options.model_explorer_host] = ( + config.option.model_explorer_host + ) logging.basicConfig(level=logging.INFO, stream=sys.stdout) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 14a9d1df41d..d34833525bc 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import logging - from collections import Counter from pprint import pformat from typing import Any, Iterable, List, Literal, Optional, Tuple, Union @@ -35,6 +34,7 @@ dbg_tosa_fb_to_json, RunnerUtil, ) +from executorch.backends.arm.test.visualize import visualize from executorch.backends.arm.tosa_mapping import extract_tensor_meta from executorch.backends.xnnpack.test.tester import Tester @@ -47,6 +47,8 @@ from tabulate import tabulate from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec from torch.fx import Graph +from typing_extensions import Self + logger = logging.getLogger(__name__) @@ -473,6 +475,22 @@ def dump_dtype_distribution( _dump_str(to_print, path_to_dump) return self + def visualize(self) -> Self: + exported_program = self._get_exported_program() + visualize(exported_program) + return self + + def _get_exported_program(self): + match self.cur: + case "Export": + return self.get_artifact() + case "ToEdge" | "Partition": + return self.get_artifact().exported_program() + case _: + raise RuntimeError( + "Can only get the exported program for the Export, ToEdge, or Partition stage." + ) + @staticmethod def _calculate_reference_output( module: Union[torch.fx.GraphModule, torch.nn.Module], inputs diff --git a/backends/arm/test/visualize.py b/backends/arm/test/visualize.py new file mode 100644 index 00000000000..ac99cd0dd44 --- /dev/null +++ b/backends/arm/test/visualize.py @@ -0,0 +1,62 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +from executorch.backends.arm.test.common import arm_test_options, get_option +from torch.export import ExportedProgram + +logger = logging.getLogger(__name__) +_model_explorer_installed = False + +try: + # pyre-ignore[21]: We keep track of whether import succeeded manually. + from model_explorer import config, visualize_from_config, visualize_pytorch + + _model_explorer_installed = True +except ImportError: + logger.warning("model-explorer is not installed, can't visualize models.") + + +def is_model_explorer_installed() -> bool: + return _model_explorer_installed + + +def get_pytest_option_host() -> str | None: + host = get_option(arm_test_options.model_explorer_host) + return str(host) if host else None + + +def get_pytest_option_port() -> int | None: + port = get_option(arm_test_options.model_explorer_port) + return int(port) if port else None + + +def visualize( + exported_program: ExportedProgram, + host: Optional[str] = None, + port: Optional[int] = None, +): + """Attempt visualizing exported_program using model-explorer.""" + + host = host if host else get_pytest_option_host() + port = port if port else get_pytest_option_port() + + if not is_model_explorer_installed(): + logger.warning("Can't visualize model since model-explorer is not installed.") + return + + # If a host is provided, we attempt connecting to an already running server. + # Note that this needs a modified model-explorer + if host: + explorer_config = ( + config() + .add_model_from_pytorch("ExportedProgram", exported_program) + .set_reuse_server(server_host=host, server_port=port) + ) + visualize_from_config(explorer_config) + else: + visualize_pytorch(exported_program)