[inductor] Make clone_graph copy node name as well (#103409)
Summary: This resolves an inconsistency between the fusion results of the two compilation passes when cpp wrapper is enabled. The unit test comes from yolov3.

Pull Request resolved: #103409
Approved by: https://github.com/eellison, https://github.com/jansel
desertfire authored and pytorchmergebot committed Jun 14, 2023
1 parent 7a2a006 commit 2d745b9
Showing 3 changed files with 44 additions and 0 deletions.
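For context, a minimal sketch (not part of this commit) of the behavior being fixed: torch.fx.Transformer regenerates node names while cloning a graph, so a name assigned to a node in the original graph is lost in the copy, and downstream passes keyed on node names can diverge between the two graphs.

import torch
import torch.fx

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)

# Rename the add node, as a fusion pass might do.
for n in gm.graph.nodes:
    if n.op == "call_function":
        n.name = "my_fused_add"

clone = torch.fx.Transformer(gm).transform()
print([n.name for n in gm.graph.nodes])     # the original keeps 'my_fused_add'
print([n.name for n in clone.graph.nodes])  # the clone regenerates a default name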
4 changes: 4 additions & 0 deletions test/inductor/test_cpp_wrapper.py
@@ -83,6 +83,9 @@ class DynamicShapesCudaWrapperCudaTests(TorchTestCase):
 
 # see https://github.com/pytorch/pytorch/issues/103194
 test_failures_cuda_wrapper = {
+    "test_batch_norm_2d_2_cuda_dynamic_shapes": test_torchinductor.TestFailure(
+        ("cuda_wrapper",)
+    ),
     "test_fft_real_input_cuda_dynamic_shapes": test_torchinductor.TestFailure(
         ("cuda_wrapper",)
     ),
@@ -223,6 +226,7 @@ class BaseTest(NamedTuple):
 # Maintain two separate test lists for cuda and cpp for now
 for item in [
     BaseTest("test_as_strided"),  # buffer reuse
+    BaseTest("test_batch_norm_2d_2"),
     BaseTest("test_bitwise"),  # int32
     BaseTest("test_bmm1"),
     BaseTest("test_bmm2"),
37 changes: 37 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -2752,6 +2752,43 @@ def test_batch_norm_2d(self):
             check_lowp=False,  # too painful to match types of bn model
         )
 
+    @skipIfRocm
+    def test_batch_norm_2d_2(self):
+        if self.device == "cpu":
+            raise unittest.SkipTest("requires CUDA")
+
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.self_0 = torch.nn.Conv2d(
+                    64,
+                    128,
+                    kernel_size=(3, 3),
+                    stride=(2, 2),
+                    padding=(1, 1),
+                    bias=False,
+                )
+                self.self_1 = torch.nn.BatchNorm2d(
+                    128,
+                    eps=0.0001,
+                    momentum=0.03,
+                    affine=True,
+                    track_running_stats=True,
+                )
+                self.self_2 = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+            def forward(self, l_input_: torch.Tensor):
+                self_0 = self.self_0(l_input_)
+                self_1 = self.self_1(self_0)
+                self_2 = self.self_2(self_1)
+                return (self_2,)
+
+        inp = torch.randn((4, 64, 192, 256), dtype=torch.float32, device="cuda")
+        mod = Repro().cuda()
+        o1 = mod(inp)
+        o2 = torch.compile(mod)(inp)
+        self.assertEqual(o1, o2)
+
     def test_layer_norm(self):
         m = torch.nn.Sequential(
             torch.nn.LayerNorm(32),
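The mismatch reproduced above only surfaces with inductor's cpp wrapper, which compiles the graph in two passes. A hedged sketch of enabling it (flag name per torch._inductor.config around the time of this commit):

import torch
import torch._inductor.config as inductor_config

# With cpp wrapper enabled, inductor compiles the graph twice; before this
# fix, the cloned graph used for the second pass could carry different node
# names, changing fusion decisions between passes.
inductor_config.cpp_wrapper = True

compiled = torch.compile(Repro().cuda())  # Repro as defined in the test above
out = compiled(torch.randn((4, 64, 192, 256), device="cuda"))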
3 changes: 3 additions & 0 deletions torch/_inductor/pattern_matcher.py
@@ -1046,6 +1046,9 @@ def run_node(self, old_node):
             new_node = super().run_node(old_node)
             if isinstance(new_node, torch.fx.Proxy):
                 new_node.node.meta.update(old_node.meta)
+                new_node.node.name = self.new_graph._graph_namespace.create_name(
+                    old_node.name, None
+                )
             return new_node
 
     return CopyGraph(input_graph).transform()
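Taken together, the patched helper amounts to this self-contained sketch, reconstructed from the diff above (the enclosing clone_graph signature is inferred from the commit title):

import torch
import torch.fx
from torch.fx import Transformer

class CopyGraph(Transformer):
    def run_node(self, old_node):
        new_node = super().run_node(old_node)
        if isinstance(new_node, torch.fx.Proxy):
            new_node.node.meta.update(old_node.meta)
            # Reserve the old node's name in the new graph's namespace so
            # both copies agree on node identity; create_name dedupes if
            # the candidate name is already taken.
            new_node.node.name = self.new_graph._graph_namespace.create_name(
                old_node.name, None
            )
        return new_node

def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
    return CopyGraph(input_graph).transform()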

1 comment on commit 2d745b9

@pytorchmergebot (Collaborator) commented:

Reverted #103409 on behalf of https://github.com/osalpekar due to a torchbench regression starting at this commit. See https://www.torch-ci.com/pytorch/pytorch/commit/2d745b95d723641e575027bd4e2fff612f61cc8f for more info.
