diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index 1b6f03512bd..019b7a4644f 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -627,6 +628,15 @@ def check_node_count(self, input: Dict[Any, int]):
 
         return self
 
+    def visualize(
+        self, reuse_server: bool = True, stage: Optional[str] = None, **kwargs
+    ):
+        # import here to avoid importing model_explorer when it is not needed which is most of the time.
+        from executorch.devtools.visualization import visualize
+
+        visualize(self.get_artifact(stage), reuse_server=reuse_server, **kwargs)
+        return self
+
     def run_method_and_compare_outputs(
         self,
         stage: Optional[str] = None,
diff --git a/devtools/visualization/__init__.py b/devtools/visualization/__init__.py
new file mode 100644
index 00000000000..645cc5d5378
--- /dev/null
+++ b/devtools/visualization/__init__.py
@@ -0,0 +1,11 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from executorch.devtools.visualization.visualization_utils import (  # noqa: F401
+    ModelExplorerServer,
+    SingletonModelExplorerServer,
+    visualize,
+)
diff --git a/devtools/visualization/visualization_utils.py b/devtools/visualization/visualization_utils.py
new file mode 100644
index 00000000000..a2ee4c60505
--- /dev/null
+++ b/devtools/visualization/visualization_utils.py
@@ -0,0 +1,119 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import subprocess
+import time
+
+from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
+from model_explorer import config, consts, visualize_from_config  # type: ignore
+from torch.export.exported_program import ExportedProgram
+
+
+class SingletonModelExplorerServer:
+    """Singleton context manager for starting a model-explorer server.
+    If multiple ModelExplorerServer contexts are nested, a single
+    server is still used.
+    """
+
+    server: None | subprocess.Popen = None
+    num_open: int = 0
+    wait_after_start = 2.0
+
+    def __init__(self, open_in_browser: bool = True, port: int | None = None):
+        if SingletonModelExplorerServer.server is None:
+            command = ["model-explorer"]
+            if not open_in_browser:
+                command.append("--no_open_in_browser")
+            if port is not None:
+                command.append("--port")
+                command.append(str(port))
+            SingletonModelExplorerServer.server = subprocess.Popen(command)
+
+    def __enter__(self):
+        SingletonModelExplorerServer.num_open = (
+            SingletonModelExplorerServer.num_open + 1
+        )
+        time.sleep(SingletonModelExplorerServer.wait_after_start)
+        return self
+
+    def __exit__(self, type, value, traceback):
+        SingletonModelExplorerServer.num_open = (
+            SingletonModelExplorerServer.num_open - 1
+        )
+        if SingletonModelExplorerServer.num_open == 0:
+            if SingletonModelExplorerServer.server is not None:
+                SingletonModelExplorerServer.server.kill()
+                try:
+                    SingletonModelExplorerServer.server.wait(
+                        SingletonModelExplorerServer.wait_after_start
+                    )
+                except subprocess.TimeoutExpired:
+                    SingletonModelExplorerServer.server.terminate()
+                SingletonModelExplorerServer.server = None
+
+
+class ModelExplorerServer:
+    """Context manager for starting a model-explorer server."""
+
+    wait_after_start = 2.0
+
+    def __init__(self, open_in_browser: bool = True, port: int | None = None):
+        command = ["model-explorer"]
+        if not open_in_browser:
+            command.append("--no_open_in_browser")
+        if port is not None:
+            command.append("--port")
+            command.append(str(port))
+        self.server = subprocess.Popen(command)
+
+    def __enter__(self):
+        time.sleep(self.wait_after_start)
+
+    def __exit__(self, type, value, traceback):
+        self.server.kill()
+        try:
+            self.server.wait(self.wait_after_start)
+        except subprocess.TimeoutExpired:
+            self.server.terminate()
+
+
+def _get_exported_program(
+    visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
+) -> ExportedProgram:
+    if isinstance(visualizable, ExportedProgram):
+        return visualizable
+    if isinstance(visualizable, (EdgeProgramManager, ExecutorchProgramManager)):
+        return visualizable.exported_program()
+    raise RuntimeError(f"Cannot get ExportedProgram from {visualizable}")
+
+
+def visualize(
+    visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
+    reuse_server: bool = True,
+    no_open_in_browser: bool = False,
+    **kwargs,
+):
+    """Wraps the visualize_from_config call from model_explorer.
+    For convenience, figures out how to find the exported_program
+    from EdgeProgramManager and ExecutorchProgramManager for you.
+
+    See https://github.com/google-ai-edge/model-explorer/wiki/4.-API-Guide#visualize-pytorch-models
+    for full documentation.
+    """
+    cur_config = config()
+    settings = consts.DEFAULT_SETTINGS
+    cur_config.add_model_from_pytorch(
+        "Executorch",
+        exported_program=_get_exported_program(visualizable),
+        settings=settings,
+    )
+    if reuse_server:
+        cur_config.set_reuse_server()
+    visualize_from_config(
+        cur_config,
+        no_open_in_browser=no_open_in_browser,
+        **kwargs,
+    )
diff --git a/devtools/visualization/visualization_utils_test.py b/devtools/visualization/visualization_utils_test.py
new file mode 100644
index 00000000000..89781ab4f43
--- /dev/null
+++ b/devtools/visualization/visualization_utils_test.py
@@ -0,0 +1,153 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import time
+
+import pytest
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+from executorch.devtools.visualization import (
+    ModelExplorerServer,
+    SingletonModelExplorerServer,
+    visualization_utils,
+    visualize,
+)
+from executorch.exir import ExportedProgram
+from model_explorer.config import ModelExplorerConfig  # type: ignore
+
+
+@pytest.fixture
+def server():
+    """Mock relevant calls in visualization.visualize and check that parameters have their expected value."""
+    monkeypatch = pytest.MonkeyPatch()
+    with monkeypatch.context():
+        _called_reuse_server = False
+
+        def mock_set_reuse_server(self):
+            nonlocal _called_reuse_server
+            _called_reuse_server = True
+
+        def mock_add_model_from_pytorch(self, name, exported_program, settings):
+            assert isinstance(exported_program, ExportedProgram)
+
+        def mock_visualize_from_config(cur_config, no_open_in_browser):
+            pass
+
+        monkeypatch.setattr(
+            ModelExplorerConfig, "set_reuse_server", mock_set_reuse_server
+        )
+        monkeypatch.setattr(
+            ModelExplorerConfig, "add_model_from_pytorch", mock_add_model_from_pytorch
+        )
+        monkeypatch.setattr(
+            visualization_utils, "visualize_from_config", mock_visualize_from_config
+        )
+        yield monkeypatch.context
+        assert _called_reuse_server, "Did not call reuse_server"
+
+
+class Linear(torch.nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int = 3,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.inputs = (torch.randn(5, 10, 25, in_features),)
+        self.fc = torch.nn.Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+        )
+
+    def get_inputs(self) -> tuple[torch.Tensor]:
+        return self.inputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.fc(x)
+
+
+def test_visualize_manual_export(server):
+    with server():
+        model = Linear(20, 30)
+        exported_program = torch.export.export(model, model.get_inputs())
+        visualize(exported_program)
+        time.sleep(3.0)
+
+
+def test_visualize_exported_program(server):
+    with server():
+        model = Linear(20, 30)
+        (
+            Tester(
+                model,
+                example_inputs=model.get_inputs(),
+            )
+            .export()
+            .visualize()
+        )
+
+
+def test_visualize_to_edge(server):
+    with server():
+        model = Linear(20, 30)
+        (
+            Tester(
+                model,
+                example_inputs=model.get_inputs(),
+            )
+            .export()
+            .to_edge()
+            .visualize()
+        )
+
+
+def test_visualize_partition(server):
+    with server():
+        model = Linear(20, 30)
+        (
+            Tester(
+                model,
+                example_inputs=model.get_inputs(),
+            )
+            .export()
+            .to_edge()
+            .partition()
+            .visualize()
+        )
+
+
+def test_visualize_to_executorch(server):
+    with server():
+        model = Linear(20, 30)
+        (
+            Tester(
+                model,
+                example_inputs=model.get_inputs(),
+            )
+            .export()
+            .to_edge()
+            .partition()
+            .to_executorch()
+            .visualize()
+        )
+
+
+if __name__ == "__main__":
+    """A test to run locally to make sure that the web browser opens up
+    automatically as intended.
+    """
+
+    test_visualize_manual_export(ModelExplorerServer)
+
+    with SingletonModelExplorerServer():
+        test_visualize_manual_export(SingletonModelExplorerServer)
+        test_visualize_exported_program(SingletonModelExplorerServer)
+        test_visualize_to_edge(SingletonModelExplorerServer)
+        test_visualize_partition(SingletonModelExplorerServer)
+        test_visualize_to_executorch(SingletonModelExplorerServer)
diff --git a/install_requirements.py b/install_requirements.py
index 3fca161a2c3..60b61271ed5 100644
--- a/install_requirements.py
+++ b/install_requirements.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -132,7 +133,7 @@ def python_is_compatible():
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION = "dev20241218"
+NIGHTLY_VERSION = "dev20250104"
 
 # The pip repository that hosts nightly torch packages.
 TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu"
@@ -169,6 +170,7 @@ def python_is_compatible():
     "tomli",  # Imported by extract_sources.py when using python < 3.11.
     "wheel",  # For building the pip package archive.
     "zstd",  # Imported by resolve_buck.py.
+    "ai-edge-model-explorer>=0.1.16",  # For visualizing ExportedPrograms
 ]
 
 # Assemble the list of requirements to actually install.