Skip to content

Commit cc67a92

Browse files
jjsjann123pytorchmergebot
authored andcommitted
fixing call_module on subscripting into generator (#81258)
named_modules() return a generator, which is not subscriptable and causes node support query to fail Pull Request resolved: #81258 Approved by: https://github.com/SherlockNoMad
1 parent dd73c97 commit cc67a92

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

test/test_fx_backends.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,38 @@ def _generate_random_inputs(self, device, inputs_meta: List[Tuple[torch.Size, to
130130
return inputs
131131

132132

133+
@skipCUDAIfRocm
134+
@dtypes(torch.float32)
135+
def test_nvfuser_call_module_backend(self, device, dtype):
136+
137+
class Model(torch.nn.Module):
138+
139+
def __init__(self):
140+
super(Model, self).__init__()
141+
self.bn = torch.nn.BatchNorm2d(3)
142+
self.relu = torch.nn.ReLU()
143+
144+
def forward(self, inp):
145+
o = self.bn(inp)
146+
o = self.relu(o)
147+
return o
148+
149+
inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device)
150+
m = Model().to(dtype=dtype, device=device)
151+
152+
# note that the traced module here contains only `call_module` node,
153+
# which isn't fused by nvfuser backend. But `nvfuser.compile` should run without error
154+
traced = symbolic_trace(m)
155+
156+
nvfuser = NvFuserBackend()
157+
compiled_module = nvfuser.compile(traced)
158+
159+
eager_result = m(inp)
160+
nvfuser_result = compiled_module(inp)
161+
162+
torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)
163+
164+
133165
@skipCUDAIfRocm
134166
@dtypes(torch.float32)
135167
def test_nvfuser_backend(self, device, dtype):

torch/fx/passes/infra/partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __get_supported_nodes(self) -> NodeList:
9090
logging.debug("Collecting supported nodes...")
9191
supported_nodes = []
9292
for node in self.graph_module.graph.nodes:
93-
if self.operator_support.is_node_supported(self.graph_module.named_modules(), node):
93+
if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node):
9494
supported_nodes.append(node)
9595
return supported_nodes
9696

0 commit comments

Comments
 (0)