[fx][split] make sure we copy node.meta over during split (#107248)
Summary:
Pull Request resolved: #107248

Previously, when we created placeholder nodes for subgraph modules, we didn't copy node.meta over.
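
Concretely, the guarantee is that metadata attached to nodes before splitting survives on the placeholder nodes of each resulting submodule. A minimal runnable sketch (the toy module and tag names are illustrative, not from the patch; only the public split_by_tags API touched by this diff is assumed):

    import torch
    from torch.fx.passes.split_utils import split_by_tags

    class M(torch.nn.Module):
        def forward(self, x, y):
            return (x + x) - (y * y)

    gm = torch.fx.symbolic_trace(M())
    for node in gm.graph.nodes:
        node.meta["name"] = node.name  # stand-in for real metadata
        if node.name in ("add", "mul"):
            node.tag = "a"
        elif node.name == "sub":
            node.tag = "b"

    split_gm = split_by_tags(gm, ["a", "b"])
    # Submodule "b" consumes the outputs of "a" through freshly created
    # placeholder nodes; with this fix those placeholders keep .meta.
    for child in split_gm.children():
        for n in child.graph.nodes:
            if n.op != "output":
                assert n.meta.get("name") == n.name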

bypass-github-pytorch-ci-checks
bypass-github-export-checks
force-merge-on-github

Test Plan: CI

Reviewed By: houseroad, sayitmemory

Differential Revision: D48330866

fbshipit-source-id: 57e2f8205695d1bd68531ebcd7a7b09b24ab98f9
842974287 authored and facebook-github-bot committed Aug 21, 2023
1 parent ad07a4b commit f7ae570
Showing 3 changed files with 43 additions and 8 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -1214,6 +1214,7 @@ exclude_patterns = [
     'test/fx/test_source_matcher_utils.py',
     'test/fx/test_subgraph_rewriter.py',
     'test/fx/test_z3_gradual_types.py',
+    'test/fx/test_fx_split.py',
     'test/jit/__init__.py',
     'test/jit/_imported_class_test/__init__.py',
     'test/jit/_imported_class_test/bar.py',
32 changes: 32 additions & 0 deletions test/fx/test_fx_split.py
@@ -0,0 +1,32 @@
+# Owner(s): ["module: fx"]
+
+import torch
+from torch.fx.passes.split_utils import split_by_tags
+
+from torch.testing._internal.common_utils import TestCase
+
+
+class TestFXSplit(TestCase):
+    def test_split_preserve_node_meta(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x, y):
+                x = x + x
+                y = y * y
+                return x - y
+
+        gm = torch.fx.symbolic_trace(TestModule())
+        for node in gm.graph.nodes:
+            node.meta["name"] = node.name
+            if node.name == "add":
+                node.tag = "a"
+            elif node.name == "mul":
+                node.tag = "b"
+            elif node.name == "sub":
+                node.tag = "c"
+
+        split_gm = split_by_tags(gm, ["a", "b", "c"])
+        for m in split_gm.children():
+            for n in m.graph.nodes:
+                if n.op != "output":
+                    self.assertIn("name", n.meta)
+                    self.assertEqual(n.meta["name"], n.name)
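
To run the new test by itself, something like pytest test/fx/test_fx_split.py from the repo root should work (a local-run assumption; per the Test Plan above, CI exercises it through the standard PyTorch test harness).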
18 changes: 10 additions & 8 deletions torch/fx/passes/split_utils.py
@@ -1,13 +1,15 @@
+import copy
 from dataclasses import dataclass, field
-from typing import List, Optional, Dict
+from typing import Dict, List, Optional
 
 import torch.fx
+from torch.fx._compatibility import compatibility
 from torch.fx.graph import map_arg
+from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
 
 from .tools_common import NodeList
-from torch.fx._compatibility import compatibility
-from torch.fx.passes.utils import lift_subgraph_as_module, HolderModule
 
-__all__ = ['getattr_recursive', 'setattr_recursive', 'Component', 'split_by_tags']
+__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
 
 @compatibility(is_backward_compatible=False)
 def getattr_recursive(obj, name):
@@ -205,14 +207,14 @@ def remap_func(x):
             # as a placeholder in current component's graph.
             if x not in comp.orig_inputs:
                 comp.orig_inputs.append(x)
+                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
+                placeholder.meta = copy.copy(x.meta)
                 comp.input_placeholders.append(
-                    comp.graph.placeholder(x.name, type_expr=x.type)
+                    placeholder
                 )
                 used_in_main[x] = None
 
-            return comp.input_placeholders[
-                next(i for i, y in enumerate(comp.orig_inputs) if x is y)
-            ]
+            return comp.input_placeholders[comp.orig_inputs.index(x)]
 
         n = comp.graph.node_copy(node, remap_func)
         n.tag = node.tag  # type: ignore[attr-defined]
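
Note that copy.copy(x.meta) takes a shallow copy: the placeholder gets its own meta dict (so later key insertions don't leak between the original graph and the subgraph), while the values themselves stay shared with the original node. A quick plain-Python illustration of that distinction:

    import copy

    meta = {"name": "add", "stack_trace": ["frame0"]}
    meta_copy = copy.copy(meta)

    meta_copy["name"] = "renamed"              # top-level keys are independent
    assert meta["name"] == "add"

    meta_copy["stack_trace"].append("frame1")  # nested values remain shared
    assert meta["stack_trace"] == ["frame0", "frame1"]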
