[FX] Fix submodule naming for subgraph split #47869

Closed

wants to merge 2 commits

Changes from 1 commit
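Summary: when split_module copies a get_attr or call_module target into a partition submodule, it previously registered the attribute under only the last atom of the dotted target path, so two distinct targets sharing a final atom (for example layer1.conv and layer2.conv) were aliased to one name. The patch flattens the full path instead. A minimal sketch of the difference (illustrative target names, not the actual implementation):

    targets = ['layer1.conv', 'layer2.conv']

    # Before: only the last atom is kept, so both targets collapse to 'conv'.
    assert [t.split('.')[-1] for t in targets] == ['conv', 'conv']

    # After: the full path is flattened, so the names stay distinct.
    assert ['_'.join(t.split('.')) for t in targets] == ['layer1_conv', 'layer2_conv']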
38 changes: 0 additions & 38 deletions test/test_fx.py
@@ -11,7 +11,6 @@
 from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph
 from torch.fx.experimental import GraphManipulation
 from torch.fx.experimental import shape_prop
-from torch.fx.experimental.subgraph_creation_example import split_module
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from copy import deepcopy

@@ -880,43 +879,6 @@ def test_inf_nan_kwds(self):
         x = torch.rand(3, 4)
         self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))

-    def test_subgraph_creation(self):
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.param = torch.nn.Parameter(torch.rand(3, 4))
-                self.linear = torch.nn.Linear(4, 5)
-
-            def forward(self, x, y):
-                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
-                w = self.linear(y).clamp(min=0.0, max=1.0)
-                return z + w
-
-        # symbolically trace model
-        my_module = MyModule()
-        my_module_traced = symbolic_trace(my_module)
-
-        # random mod partitioning
-        partition_counter = 0
-        NPARTITIONS = 3
-
-        def mod_partition(node: Node):
-            nonlocal partition_counter
-            partition = partition_counter % NPARTITIONS
-            partition_counter = (partition_counter + 1) % NPARTITIONS
-            return partition
-
-        # split module in module with submodules
-        module_with_submodules = split_module(my_module_traced, my_module, mod_partition)
-
-        x = torch.rand(3, 4)
-        y = torch.rand(3, 4)
-
-        orig_out = my_module_traced(x, y)
-        submodules_out = module_with_submodules(x, y)
-
-        self.assertEqual(orig_out, submodules_out)
-
     def test_deepcopy_recursion_depth(self):
         depth = sys.getrecursionlimit() + 20
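The relocated test (in the next file) drives split_module with a simple round-robin policy that ignores graph structure entirely. As a standalone sketch of that policy (stand-in objects in place of fx Nodes):

    NPARTITIONS = 3
    partition_counter = 0

    def mod_partition(node):
        # The node is ignored: partitions are assigned 0, 1, 2, 0, 1, 2, ...
        global partition_counter
        partition = partition_counter % NPARTITIONS
        partition_counter = (partition_counter + 1) % NPARTITIONS
        return partition

    print([mod_partition(object()) for _ in range(7)])  # [0, 1, 2, 0, 1, 2, 0]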
56 changes: 56 additions & 0 deletions test/test_fx_experimental.py
@@ -1,4 +1,5 @@
 import torch
+import unittest
 from typing import Dict
 from torch.fx.symbolic_trace import symbolic_trace
 from torch.fx.graph_module import GraphModule
@@ -8,6 +9,7 @@
 from torch.fx.experimental.rewriter import RewritingTracer
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
+from torch.fx.experimental.subgraph_creation_example import split_module
 from torch.fx.experimental.partitioner_utils import (
     NodeLatency,
     get_partition_to_latency_mapping,
@@ -17,6 +19,13 @@
 )
 from typing import Union, Callable

+try:
+    from torchvision.models import resnet18
+    HAS_TORCHVISION = True
+except ImportError:
+    HAS_TORCHVISION = False
+skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
+

 def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
     return GraphModule(
@@ -485,6 +494,53 @@ def forward(self, a, b):
         # Confirm that the output is correct
         self.assertEqual(traced(3, 3), m(3, 3))

+    def test_subgraph_creation(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x, y):
+                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
+                w = self.linear(y).clamp(min=0.0, max=1.0)
+                return z + w
+
+        # symbolically trace model
+        my_module = MyModule()
+        my_module_traced = symbolic_trace(my_module)
+
+        # random mod partitioning
+        partition_counter = 0
+        NPARTITIONS = 3
+
+        def mod_partition(node: Node):
+            nonlocal partition_counter
+            partition = partition_counter % NPARTITIONS
+            partition_counter = (partition_counter + 1) % NPARTITIONS
+            return partition
+
+        # split the traced module into a module with submodules
+        module_with_submodules = split_module(my_module_traced, my_module, mod_partition)
+
+        x = torch.rand(3, 4)
+        y = torch.rand(3, 4)
+
+        orig_out = my_module_traced(x, y)
+        submodules_out = module_with_submodules(x, y)
+
+        self.assertEqual(orig_out, submodules_out)
+
+    @skipIfNoTorchVision
+    def test_subgraph_trivial_resnet(self):
+        # Smoke test: trivially splitting ResNet into one partition should work.
+        # An earlier bug caused submodule names to be aliased.
+        m = resnet18()
+        traced = symbolic_trace(m)
+        a = torch.rand(64, 3, 7, 7)
+        module_with_submodules = split_module(traced, m, lambda node: 0)
+        module_with_submodules(a)
+
     def test_traceable_function_with_nonstandard_name(self):
         def foo(x):
             return torch.relu(x)
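The new smoke test requires torchvision, but the aliasing scenario it covers only needs two submodules whose dotted targets end in the same atom. A dependency-free repro along the same lines (hypothetical Inner/Outer modules standing in for resnet18):

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.experimental.subgraph_creation_example import split_module

    class Inner(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 1)

        def forward(self, x):
            return self.conv(x)

    class Outer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = Inner()
            self.layer2 = Inner()

        def forward(self, x):
            return self.layer2(self.layer1(x))

    m = Outer()
    traced = symbolic_trace(m)
    # A single partition receives both call_module targets, 'layer1.conv' and
    # 'layer2.conv' -- exactly the pair that last-atom naming used to alias.
    module_with_submodules = split_module(traced, m, lambda node: 0)

    x = torch.rand(1, 3, 4, 4)
    assert torch.allclose(module_with_submodules(x), traced(x))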
3 changes: 2 additions & 1 deletion torch/fx/experimental/subgraph_creation_example.py
@@ -115,7 +115,8 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona
                 if not hasattr(target_attr, atom):
                     raise RuntimeError(f'Operator target {node.target} not found!')
                 target_attr = getattr(target_attr, atom)
-                target = target_atoms[-1]
+                # target = target_atoms[-1]
+                target = '_'.join(target_atoms)
                 partition.targets[target] = target_attr

             assert isinstance(gathered_args, tuple)
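For reference, the patched logic reduces to the following hedged sketch (a standalone adaptation of the hunk above, not the full split_module implementation):

    import torch

    def resolve_and_flatten(root: torch.nn.Module, qualified_target: str):
        # Walk the dotted path atom by atom, as the patched code does.
        target_atoms = qualified_target.split('.')
        target_attr = root
        for atom in target_atoms:
            if not hasattr(target_attr, atom):
                raise RuntimeError(f'Operator target {qualified_target} not found!')
            target_attr = getattr(target_attr, atom)
        # Old: target_atoms[-1] names both 'layer1.conv' and 'layer2.conv' as
        # 'conv'. New: joining the atoms keeps the flattened names distinct.
        return '_'.join(target_atoms), target_attr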