From a97ce2d90277666a52c71659fb023e84f18a6edd Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Thu, 1 Feb 2024 21:02:04 -0800
Subject: [PATCH] support non-persistent buffers (#1817)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Basic support for non-persistent buffers, which are buffers that do not show
up in the state dict.

One weird twist is that most of our other systems (FX, aot_export, dynamo)
have completely buggy handling of non-persistent buffers. I tried to go on a
wild goose chase to fix them all, but it got to be too much. So I introduced
some sad rewrite passes in `_export` to make the final state dict correctly
align with the original module's state dict.

This exposed some bugs/ambiguous handling of parameters/buffers in existing
test code. For example, `TestSaveLoad.test_save_buffer` traced over a module
that was not in the root module hierarchy and caused some weird behavior. I
think we should error explicitly on use cases like this:
https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the
tests or skipped them.

As a side effect, this diff tightened up quite a few sloppy behaviors around
state dict handling:
- Tensor attributes were getting promoted to be buffers—bad!
- Tracing through a module not in the children of the root module would add
  its parameters/buffers to the state dict—bad! This behavior is unlikely to
  show up in user code since the model would be totally broken, but did show
  up in a bunch of tests.

#buildmore

Differential Revision: D53340041
---
 backends/xnnpack/test/ops/add.py           |  2 +-
 backends/xnnpack/test/ops/linear.py        |  4 +--
 .../test/test_xnnpack_utils_classes.py     |  9 ++++--
 backends/xnnpack/xnnpack_preprocess.py     |  1 +
 exir/backend/backend_api.py                |  3 +-
 exir/lowered_backend_module.py             | 23 ++++++++++---
 exir/passes/constant_prop_pass.py          |  1 +
 exir/program/_program.py                   |  8 ++++-
 exir/serde/export_serialize.py             |  1 +
 exir/tests/test_passes.py                  |  6 +++-
 profiler/test/test_profiler_e2e.py         |  4 +--
 sdk/bundled_program/util/TARGETS           |  1 -
 sdk/bundled_program/util/test_util.py      | 32 ++++++++++++++++---
 test/end2end/exported_module.py            |  4 +--
 14 files changed, 77 insertions(+), 22 deletions(-)

diff --git a/backends/xnnpack/test/ops/add.py b/backends/xnnpack/test/ops/add.py
index 0027d04c5f9..48eb836f322 100644
--- a/backends/xnnpack/test/ops/add.py
+++ b/backends/xnnpack/test/ops/add.py
@@ -33,7 +33,7 @@ def forward(self, x):
 class AddConstant(torch.nn.Module):
     def __init__(self, constant):
         super().__init__()
-        self._constant = constant
+        self.register_buffer("_constant", constant, persistent=False)
 
     def forward(self, x):
         out1 = x + self._constant
diff --git a/backends/xnnpack/test/ops/linear.py b/backends/xnnpack/test/ops/linear.py
index d7a9896d08d..b1474d505d2 100644
--- a/backends/xnnpack/test/ops/linear.py
+++ b/backends/xnnpack/test/ops/linear.py
@@ -40,8 +40,8 @@ def test_fp32_addmm(self):
         class AddMMModule(torch.nn.Module):
             def __init__(self, in_size, out_size):
                 super().__init__()
-                self.mat = torch.randn(out_size, in_size)
-                self.bias = torch.randn(1, out_size)
+                self.mat = torch.nn.Parameter(torch.randn(out_size, in_size))
+                self.bias = torch.nn.Parameter(torch.randn(1, out_size))
 
             def forward(self, x):
                 return torch.addmm(self.bias, x, torch.transpose(self.mat, 0, 1))
diff --git a/backends/xnnpack/test/test_xnnpack_utils_classes.py b/backends/xnnpack/test/test_xnnpack_utils_classes.py
index 679eb0021d4..50a1914a56b 100644
--- a/backends/xnnpack/test/test_xnnpack_utils_classes.py
+++ b/backends/xnnpack/test/test_xnnpack_utils_classes.py
@@ -18,10 +18,12 @@ def __init__(self, num_sequences, ops_per_sequence):
         super().__init__()
         self.num_ops = num_sequences * ops_per_sequence
         self.num_sequences = num_sequences
-        self.op_sequence = [[] for _ in range(num_sequences)]
-        for seq in range(num_sequences):
+
+        self.op_sequence = torch.nn.ModuleList()
+        for _ in range(num_sequences):
+            inner = torch.nn.ModuleList()
             for _ in range(ops_per_sequence):
-                self.op_sequence[seq].append(
+                inner.append(
                     torch.nn.Conv2d(
                         in_channels=1,
                         out_channels=1,
@@ -30,6 +32,7 @@ def __init__(self, num_sequences, ops_per_sequence):
                         bias=False,
                     )
                 )
+            self.op_sequence.append(inner)
 
     def forward(self, x):
         for seq in self.op_sequence:
diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py
index efcfff9d0e7..4e52cd4c59b 100644
--- a/backends/xnnpack/xnnpack_preprocess.py
+++ b/backends/xnnpack/xnnpack_preprocess.py
@@ -210,6 +210,7 @@ def preprocess(
             verifier=EXIREdgeDialectVerifier(
                 check_edge_ops=False, enable=False, class_only=True
             ),
+            constants=ep.constants,
         )
 
         # XNNPACK Delegate Specific Passes
diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py
index ab354c3cf1b..4d85a73d73e 100644
--- a/exir/backend/backend_api.py
+++ b/exir/backend/backend_api.py
@@ -350,7 +350,7 @@ def to_backend(
 
     # TODO(angelayi): Update this signature in a less manual way (maybe through
     # retracing)
-    new_signature, new_state_dict = _get_new_signature(
+    new_signature, new_state_dict, new_constants = _get_new_signature(
         edge_program, tagged_graph_module
     )
     return ExportedProgram(
@@ -362,4 +362,5 @@
         module_call_graph=copy.deepcopy(edge_program.module_call_graph),
         example_inputs=None,
         verifier=edge_program.verifier,
+        constants=new_constants,
     )
diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py
index 58e4356a199..b0279d5e752 100644
--- a/exir/lowered_backend_module.py
+++ b/exir/lowered_backend_module.py
@@ -388,7 +388,11 @@ def arrange_graph_placeholders(
 # TODO Don't regenerate new signature manually.
 def _get_new_signature(
     original_program: ExportedProgram, gm: torch.fx.GraphModule
-) -> Tuple[ExportGraphSignature, Dict[str, Union[torch.Tensor, torch.nn.Parameter]]]:
+) -> Tuple[
+    ExportGraphSignature,
+    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
+    Dict[str, Union[torch.Tensor, torch.ScriptObject]],
+]:
     old_signature = original_program.graph_signature
 
     input_specs = []
@@ -397,6 +401,9 @@ def _get_new_signature(
         input_specs=input_specs, output_specs=output_specs
     )
     new_state_dict = {}
+    new_constants = {}
+
+    non_persistent_buffers = set(old_signature.non_persistent_buffers)
 
     for node in gm.graph.nodes:
         if node.op == "placeholder":
@@ -417,17 +424,24 @@ def _get_new_signature(
                 ]
             elif node.name in old_signature.inputs_to_buffers:
                 buffer_name = old_signature.inputs_to_buffers[node.name]
+                persistent = buffer_name not in non_persistent_buffers
 
                 # add buffer to graph signature
                 input_specs.append(
                     InputSpec(
                         kind=InputKind.BUFFER,
                         arg=TensorArgument(name=node.name),
                         target=buffer_name,
+                        persistent=persistent,
                     )
                 )
                 # add param to new_state_dict
-                new_state_dict[buffer_name] = original_program.state_dict[buffer_name]
+                if persistent:
+                    new_state_dict[buffer_name] = original_program.state_dict[
+                        buffer_name
+                    ]
+                else:
+                    new_constants[buffer_name] = original_program.constants[buffer_name]
             else:
                 # not param or buffer then user input
                 input_specs.append(
@@ -449,7 +463,7 @@
             )
         )
 
-    return new_signature, new_state_dict
+    return new_signature, new_state_dict, new_constants
 
 
 def create_exported_program_from_submodule(
@@ -472,7 +486,7 @@
     submodule = arrange_graph_placeholders(submodule, owning_program)
 
     # Get updated graph signature
-    subgraph_signature, subgraph_state_dict = _get_new_signature(
+    subgraph_signature, subgraph_state_dict, subgraph_constants = _get_new_signature(
         owning_program, submodule
     )
 
@@ -484,6 +498,7 @@
         range_constraints=copy.deepcopy(owning_program.range_constraints),
         module_call_graph=[],
         verifier=owning_program.verifier,
+        constants=subgraph_constants,
     )
 
 
diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py
index ee4bfd8c4ff..9617951ce44 100644
--- a/exir/passes/constant_prop_pass.py
+++ b/exir/passes/constant_prop_pass.py
@@ -108,6 +108,7 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
             kind=InputKind.BUFFER,
             arg=TensorArgument(name=const_placeholder_node.name),
             target=prop_constant_tensor_fqn,
+            persistent=True,
         )
         prop_constant_data.append(prop_constant_node_input_spec)
         buffers.append(prop_constant_tensor_fqn)
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 343fcbf761f..92e3d2bea12 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -78,7 +78,12 @@ def _get_updated_graph_signature(
             else type(old_input_spec.arg)(node.name)
         )
         new_input_specs.append(
-            InputSpec(old_input_spec.kind, arg, old_input_spec.target)
+            InputSpec(
+                old_input_spec.kind,
+                arg,
+                old_input_spec.target,
+                persistent=old_input_spec.persistent,
+            )
         )
         i += 1
 
@@ -196,6 +201,7 @@ def lift_constant_tensor_pass(ep):
                 kind=InputKind.BUFFER,
                 arg=TensorArgument(name=const_placeholder_node.name),
                 target=constant_tensor_fqn,
+                persistent=True,
             )
         )
         buffers.append(constant_tensor_fqn)
diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py
index e01e86db2e6..8cf12fb995d 100644
--- a/exir/serde/export_serialize.py
+++ b/exir/serde/export_serialize.py
@@ -739,6 +739,7 @@ def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
                 buffer=InputToBufferSpec(
                     arg=TensorArgument(name=spec.arg.name),
                     buffer_name=spec.target,  # pyre-ignore
+                    persistent=spec.persistent,  # pyre-ignore
                 )
             )
         elif spec.kind == ep.InputKind.CONSTANT_TENSOR:
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 65f3b11c8d7..d70d952d70e 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -617,8 +617,12 @@ def test_compile_fix_broken_ops(self) -> None:
         model: torch.nn.Linear = torch.nn.Linear(5, 5)
 
         class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.model = model
+
             def forward(self, inp: torch.Tensor) -> torch.Tensor:
-                return model(inp)
+                return self.model(inp)
 
         f = Foo()
 
diff --git a/profiler/test/test_profiler_e2e.py b/profiler/test/test_profiler_e2e.py
index b8421efb08d..f5df82176ee 100644
--- a/profiler/test/test_profiler_e2e.py
+++ b/profiler/test/test_profiler_e2e.py
@@ -34,8 +34,8 @@
 class Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
-        self.b = 2 * torch.ones(2, 2, dtype=torch.float)
+        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.float))
+        self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.float))
 
     def forward(self, x):
         a = torch.mul(self.a, x)
diff --git a/sdk/bundled_program/util/TARGETS b/sdk/bundled_program/util/TARGETS
index 231ba5cb26d..17d19dfb29a 100644
--- a/sdk/bundled_program/util/TARGETS
+++ b/sdk/bundled_program/util/TARGETS
@@ -11,7 +11,6 @@ python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/exir:lib",
-        "//executorch/exir:schema",
         "//executorch/sdk/bundled_program:config",
     ],
 )
diff --git a/sdk/bundled_program/util/test_util.py b/sdk/bundled_program/util/test_util.py
index 3fb975772a0..c9277964860 100644
--- a/sdk/bundled_program/util/test_util.py
+++ b/sdk/bundled_program/util/test_util.py
@@ -7,7 +7,7 @@
 # pyre-strict
 import random
 import string
-from typing import List, Tuple
+from typing import Callable, List, Tuple
 
 import torch
 
@@ -19,6 +19,7 @@
     MethodTestSuite,
 )
 from torch.export import export, WrapperModule
+from torch.export.unflatten import _assign_attr, _AttrKind
 
 # A hacky integer to deal with a mismatch between execution plan and complier.
 #
@@ -45,8 +46,8 @@
 class SampleModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.a: torch.Tensor = 3 * torch.ones(2, 2, dtype=torch.int32)
-        self.b: torch.Tensor = 2 * torch.ones(2, 2, dtype=torch.int32)
+        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32))
+        self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.int32))
         self.method_names = ["encode", "decode"]
 
     def encode(
@@ -228,6 +229,28 @@ def get_random_test_suites_with_eager_model(
     return inputs_per_program, method_test_suites
 
 
+class StatefulWrapperModule(torch.nn.Module):
+    """A version of wrapper module that preserves parameters/buffers.
+
+    Use this if you are planning to wrap a non-forward method on an existing
+    module.
+    """
+
+    def __init__(self, base_mod, method) -> None:  # pyre-ignore
+        super().__init__()
+        state_dict = base_mod.state_dict()
+        for name, value in base_mod.named_parameters():
+            _assign_attr(value, self, name, _AttrKind.PARAMETER)
+        for name, value in base_mod.named_buffers():
+            _assign_attr(
+                value, self, name, _AttrKind.BUFFER, persistent=name in state_dict
+            )
+        self.fn = method  # pyre-ignore
+
+    def forward(self, *args, **kwargs):  # pyre-ignore
+        return self.fn(*args, **kwargs)
+
+
 def get_common_executorch_program() -> Tuple[
     ExecutorchProgramManager, List[MethodTestSuite]
 ]:
@@ -246,7 +269,8 @@
     # Trace to FX Graph and emit the program
     method_graphs = {
         m_name: export(
-            WrapperModule(getattr(eager_model, m_name)), capture_inputs[m_name]
+            StatefulWrapperModule(eager_model, getattr(eager_model, m_name)),
+            capture_inputs[m_name],
        )
         for m_name in eager_model.method_names
     }
diff --git a/test/end2end/exported_module.py b/test/end2end/exported_module.py
index 889cc5fdfe8..40e729e9da7 100644
--- a/test/end2end/exported_module.py
+++ b/test/end2end/exported_module.py
@@ -155,9 +155,9 @@ def __init__(self, method):
         # These cleanup passes are required to convert the `add` op to its out
         # variant, along with some other transformations.
        for method_name, method_input in method_name_to_args.items():
-            module = WrapperModule(getattr(eager_module, method_name))
+            # if not isinstance(eager_module, torch.nn.Module):
             exported_methods[method_name] = export(
-                module,
+                eager_module,
                 method_input,
                 dynamic_shapes=method_name_to_dynamic_shapes[method_name]
                 if method_name_to_dynamic_shapes