From 6e0b6cb276cca3708e3ce17f1d8c6a3a050a532a Mon Sep 17 00:00:00 2001 From: Juntian Liu Date: Tue, 3 Jun 2025 11:10:16 -0700 Subject: [PATCH] Reusable test framework the Inspector tests (#11314) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/11314 This Diff introduces a reusable test framework. The inspector_test_utils.py file provides methods to instantiate the model, retrieve expected outputs, and assert the correctness of actual outputs in a extensible way. Also, update the intermediate_output_capturer_test to take advantage of the reusable framework and make the intermediate_output_capturer_test more extensible by using reusable setup method, making it easier to add new models and tests. Reviewed By: Gasoonjia Differential Revision: D75803288 --- devtools/inspector/tests/TARGETS | 13 ++ .../inspector/tests/inspector_test_utils.py | 118 ++++++++++++++ .../intermediate_output_capturer_test.py | 151 ++++++------------ 3 files changed, 177 insertions(+), 105 deletions(-) create mode 100644 devtools/inspector/tests/inspector_test_utils.py diff --git a/devtools/inspector/tests/TARGETS b/devtools/inspector/tests/TARGETS index 78450dc5fe2..b5fbeda215b 100644 --- a/devtools/inspector/tests/TARGETS +++ b/devtools/inspector/tests/TARGETS @@ -1,4 +1,5 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbcode_macros//build_defs:python_library.bzl", "python_library") oncall("executorch") @@ -13,6 +14,7 @@ python_unittest( "//executorch/devtools/inspector:inspector", "//executorch/devtools/inspector:lib", "//executorch/exir:lib", + "//executorch/devtools/inspector/tests:inspector_test_utils", ], ) @@ -48,5 +50,16 @@ python_unittest( "//executorch/devtools/inspector:lib", "//executorch/devtools/inspector:intermediate_output_capturer", "//executorch/exir:lib", + "//executorch/devtools/inspector/tests:inspector_test_utils", + ], +) + +python_library( + name = "inspector_test_utils", + srcs = [ + "inspector_test_utils.py", + ], + deps = [ + "//caffe2:torch", ], ) diff --git a/devtools/inspector/tests/inspector_test_utils.py b/devtools/inspector/tests/inspector_test_utils.py new file mode 100644 index 00000000000..b9d4b1882b8 --- /dev/null +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvlLinearModel(nn.Module): + """ + A neural network model with a convolutional layer followed by a linear layer. + """ + + def __init__(self): + super(ConvlLinearModel, self).__init__() + self.conv_layer = nn.Conv2d( + in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1 + ) + self.conv_layer.weight = nn.Parameter( + torch.tensor([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]) + ) + self.conv_layer.bias = nn.Parameter(torch.tensor([0.0])) + + self.linear_layer = nn.Linear(in_features=4, out_features=2) + self.linear_layer.weight = nn.Parameter( + torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]) + ) + self.linear_layer.bias = nn.Parameter(torch.tensor([0.0, 0.0])) + self.additional_bias = nn.Parameter( + torch.tensor([0.5, -0.5]), requires_grad=False + ) + self.scale_factor = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False) + + def forward(self, x): + x = self.conv_layer(x) + x = x.view(x.size(0), -1) + x = self.linear_layer(x) + x = x + self.additional_bias + x = x - 0.1 + x = x * self.scale_factor + x = x / (self.scale_factor + 1.0) + x = F.relu(x) + x = torch.sigmoid(x) + output1, output2 = torch.split(x, 1, dim=1) + return output1, output2 + + @staticmethod + def get_input(): + """ + Returns the pre-defined input tensor for this model. + """ + return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True) + + @staticmethod + def get_expected_intermediate_outputs(): + """ + Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input. + """ + return { + (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), + (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), + (12,): torch.tensor( + [ + [0.1000, 0.5000], + [0.2000, 0.6000], + [0.3000, 0.7000], + [0.4000, 0.8000], + ] + ), + (13,): torch.tensor([[5.0000, 14.1200]]), + (14,): torch.tensor([[5.5000, 13.6200]]), + (15,): torch.tensor([[5.4000, 13.5200]]), + (16,): torch.tensor([[10.8000, 6.7600]]), + (17,): torch.tensor([3.0000, 1.5000]), + (18,): torch.tensor([[3.6000, 4.5067]]), + (19,): torch.tensor([[3.6000, 4.5067]]), + (20,): torch.tensor([[0.9734, 0.9891]]), + (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], + } + + +# Global model registry +model_registry = { + "ConvLinearModel": ConvlLinearModel, + # Add new models here +} + + +def check_if_final_outputs_match(model_name, actual_outputs_with_handles): + """ + Checks if the actual outputs match the expected outputs for the specified model. + Returns True if all outputs match, otherwise returns False. + """ + model_instance = model_registry[model_name] + expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs() + if len(actual_outputs_with_handles) != len(expected_outputs_with_handles): + return False + for debug_handle, expected_output in expected_outputs_with_handles.items(): + actual_output = actual_outputs_with_handles.get(debug_handle) + if actual_output is None: + return False + if isinstance(expected_output, list): + if not isinstance(actual_output, list): + return False + if len(actual_output) != len(expected_output): + return False + for actual, expected in zip(actual_output, expected_output): + if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-5): + return False + else: + if not torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5): + return False + return True diff --git a/devtools/inspector/tests/intermediate_output_capturer_test.py b/devtools/inspector/tests/intermediate_output_capturer_test.py index 7ad673c7cfe..3c8d2487e70 100644 --- a/devtools/inspector/tests/intermediate_output_capturer_test.py +++ b/devtools/inspector/tests/intermediate_output_capturer_test.py @@ -6,127 +6,68 @@ # pyre-unsafe - import unittest import torch -import torch.nn as nn -import torch.nn.functional as F from executorch.devtools.inspector._intermediate_output_capturer import ( IntermediateOutputCapturer, ) - +from executorch.devtools.inspector.tests.inspector_test_utils import ( + check_if_final_outputs_match, + model_registry, +) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge from torch.export import export, ExportedProgram from torch.fx import GraphModule class TestIntermediateOutputCapturer(unittest.TestCase): - @classmethod - def setUpClass(cls): - class TestModule(nn.Module): - def __init__(self): - super(TestModule, self).__init__() - self.conv = nn.Conv2d( - in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1 - ) - self.conv.weight = nn.Parameter( - torch.tensor( - [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]] - ) - ) - self.conv.bias = nn.Parameter(torch.tensor([0.0])) - - self.linear = nn.Linear(in_features=4, out_features=2) - self.linear.weight = nn.Parameter( - torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]) - ) - self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0])) - self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False) - self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False) - - def forward(self, x): - x = self.conv(x) - x = x.view(x.size(0), -1) - x = self.linear(x) - x = x + self.bias - x = x - 0.1 - x = x * self.scale - x = x / (self.scale + 1.0) - x = F.relu(x) - x = torch.sigmoid(x) - x1, x2 = torch.split(x, 1, dim=1) - return x1, x2 - - cls.model = TestModule() - cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True) - cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True) - cls.edge_program_manager: EdgeProgramManager = to_edge( - cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True) + def _set_up_model(self, model_name): + model = model_registry[model_name]() + input_tensor = model.get_input() + aten_model: ExportedProgram = export(model, (input_tensor,), strict=True) + edge_program_manager: EdgeProgramManager = to_edge( + aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True) ) - cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[ + graph_module: GraphModule = edge_program_manager._edge_programs[ "forward" ].module() - cls.capturer = IntermediateOutputCapturer(cls.graph_module) - cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input) - - def test_keying_with_debug_handle_tuple(self): - for key in self.intermediate_outputs.keys(): - self.assertIsInstance(key, tuple) - - def test_tensor_cloning_and_detaching(self): - for output in self.intermediate_outputs.values(): - if isinstance(output, torch.Tensor): - self.assertFalse(output.requires_grad) - self.assertTrue(output.is_leaf) - - def test_placeholder_nodes_are_skipped(self): - for node in self.graph_module.graph.nodes: - if node.op == "placeholder": - self.assertNotIn( - node.meta.get("debug_handle"), self.intermediate_outputs + capturer = IntermediateOutputCapturer(graph_module) + intermediate_outputs = capturer.run_and_capture(input_tensor) + return input_tensor, graph_module, capturer, intermediate_outputs + + def test_models(self): + available_models = list(model_registry.keys()) + for model_name in available_models: + with self.subTest(model=model_name): + input_tensor, graph_module, capturer, intermediate_outputs = ( + self._set_up_model(model_name) ) - def test_multiple_outputs_capture(self): - outputs = self.capturer.run_and_capture(self.input) - for output in outputs.values(): - if isinstance(output, tuple): - self.assertEqual(len(output), 2) - for part in output: - self.assertIsInstance(part, torch.Tensor) - - def test_capture_correct_outputs(self): - expected_outputs_with_handles = { - (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), - (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), - (12,): torch.tensor( - [[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]] - ), - (13,): torch.tensor([[5.0000, 14.1200]]), - (14,): torch.tensor([[5.5000, 13.6200]]), - (15,): torch.tensor([[5.4000, 13.5200]]), - (16,): torch.tensor([[10.8000, 6.7600]]), - (17,): torch.tensor([3.0000, 1.5000]), - (18,): torch.tensor([[3.6000, 4.5067]]), - (19,): torch.tensor([[3.6000, 4.5067]]), - (20,): torch.tensor([[0.9734, 0.9891]]), - (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], - } - self.assertEqual( - len(self.intermediate_outputs), len(expected_outputs_with_handles) - ) - - for debug_handle, expected_output in expected_outputs_with_handles.items(): - actual_output = self.intermediate_outputs.get(debug_handle) - self.assertIsNotNone(actual_output) - if isinstance(expected_output, list): - self.assertIsInstance(actual_output, list) - self.assertEqual(len(actual_output), len(expected_output)) - for actual, expected in zip(actual_output, expected_output): - self.assertTrue( - torch.allclose(actual, expected, rtol=1e-4, atol=1e-5) - ) - else: + # Test keying with debug handle tuple + for key in intermediate_outputs.keys(): + self.assertIsInstance(key, tuple) + + # Test tensor cloning and detaching + for output in intermediate_outputs.values(): + if isinstance(output, torch.Tensor): + self.assertFalse(output.requires_grad) + self.assertTrue(output.is_leaf) + + # Test placeholder nodes are skipped + for node in graph_module.graph.nodes: + if node.op == "placeholder": + self.assertNotIn(node.meta.get("debug_handle"), node.meta) + + # Test multiple outputs capture + outputs = capturer.run_and_capture(input_tensor) + for output in outputs.values(): + if isinstance(output, tuple): + self.assertEqual(len(output), 2) + for part in output: + self.assertIsInstance(part, torch.Tensor) + + # Test capture correct outputs self.assertTrue( - torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5) + check_if_final_outputs_match(model_name, intermediate_outputs) )