diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py
index 3544d7e3e12..cd1bb414a66 100644
--- a/exir/passes/constant_prop_pass.py
+++ b/exir/passes/constant_prop_pass.py
@@ -16,7 +16,6 @@
     get_buffer,
     get_lifted_tensor_constant,
     get_param,
-    is_buffer,
     is_lifted_tensor_constant,
     is_param,
 )
@@ -78,6 +77,16 @@ def get_data(
     return None
 
 
+def is_constant_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool:
+    """Checks if the given node is a constant (non-mutated) buffer."""
+
+    if node.target not in program.graph_signature.inputs_to_buffers:
+        return False
+    fqn = program.graph_signature.inputs_to_buffers[node.target]
+    # A buffer that is mutated is not constant and must not be folded.
+    return fqn not in program.graph_signature.buffers_to_mutate.values()
+
+
 def get_constant_placeholder_dict(
     exported_program: ExportedProgram,
 ) -> OrderedDict[torch.fx.Node, torch.Tensor]:
@@ -85,15 +94,12 @@
     Returns a dictionary of placeholder node -> constant tensor.
     """
     const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
-    for node in exported_program.graph.nodes:
-        if node.op != "placeholder":
-            continue
-
+    for node in exported_program.graph.find_nodes(op="placeholder"):
         if is_param(exported_program, node):
             const_node_to_tensor[node] = cast(
                 torch.Tensor, get_param(exported_program, node)
             )
-        elif is_buffer(exported_program, node):
+        elif is_constant_buffer(exported_program, node):
             const_node_to_tensor[node] = cast(
                 torch.Tensor, get_buffer(exported_program, node)
             )
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index ff076a7345e..5691bf870e2 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -1594,6 +1594,34 @@ def forward(self, x):
             gm.code
         )
 
+    def test_constant_prop_pass_for_mutable_buffers(self) -> None:
+        def count_adds(gm: torch.fx.GraphModule) -> int:
+            return len(
+                gm.graph.find_nodes(
+                    op="call_function", target=exir_ops.edge.aten.add.Tensor
+                )
+            )
+
+        class MutableStateModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("state", torch.zeros(1))
+
+            def forward(self, x):
+                x = x + self.state
+                # Add 1 (a constant) to the mutable buffer `state`.
+                self.state.add_(1)
+                return x
+
+        edge_manager = to_edge(
+            export(MutableStateModule(), (torch.zeros(1),), strict=True)
+        )
+        self.assertEqual(count_adds(edge_manager.exported_program().graph_module), 2)
+        edge_manager._edge_programs["forward"] = constant_prop_pass(
+            edge_manager._edge_programs["forward"]
+        )
+        self.assertEqual(count_adds(edge_manager.exported_program().graph_module), 2)
+
     def test_constant_prop_pass_for_no_grad(self) -> None:
         class LSTM(torch.nn.Module):
             def __init__(self, input_size, hidden_size, num_layers):