Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/fx/test_fx_xform_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,32 @@ def forward(self, x):
self.assertEqual(len(gm2._erase_node_hooks), 0)
self.assertEqual(len(gm2._deepcopy_hooks), 0)

@torch._inductor.config.patch("trace.provenance_tracking", True)
def test_graph_transform_observer_replace(self):
# the node sohuld should not be duplicated
class Model(torch.nn.Module):
def forward(self, x):
y = x + 1
z = y * 2
w = y * 3
return z, w

model = Model()
gm = symbolic_trace(model)

with GraphTransformObserver(gm, "test"):
for node in gm.graph.nodes:
if node.name == "add":
new_node = gm.graph.call_function(
torch.ops.aten.add.Tensor, (node.args[0], node.args[1])
)
node.replace_all_uses_with(new_node)
new_node.name = "new_add"

self.assertEqual(len(new_node.meta["from_node"]), 1)
self.assertEqual(new_node.meta["from_node"][0].name, "add")
self.assertEqual(new_node.meta["from_node"][0].pass_name, "test")


if __name__ == "__main__":
raise RuntimeError(
Expand Down
3 changes: 2 additions & 1 deletion torch/_inductor/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,8 @@ def create_mapping_pre_post_grad_nodes(
return empty_return

if not isinstance(pre_grad_graph_id, int):
log.error("Provenance tacking error: pre_grad_graph_id is not an int")
# pre_grad_graph_id may be empty if there's no pre_grad graph
# and there's only a backward graph from backward pass engine
return empty_return

pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
Expand Down
6 changes: 6 additions & 0 deletions torch/fx/passes/graph_transform_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def on_node_replace(old: Node, new: str, user: Node):

assert isinstance(new_node, Node)

# replace hook is called once for each user of old
# this avoids adding duplicated source nodes
added_nodes = {s.name for s in new_node.meta.get("from_node", [])}
if old.name in added_nodes:
return

action = [NodeSourceAction.REPLACE]
if new_node.name in self.created_nodes:
action.append(NodeSourceAction.CREATE)
Expand Down
Loading