Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions test/dynamo/test_higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,40 @@ def inner(x, y):
# get_item call created by the flatten/unflatten logic in HOP speculation.
self.assertEqual(cnt.op_count, ifdynstaticdefault(3, 4))

def test_map_lowers_to_graph(self):
backend = EagerAndRecordGraphs()
cnt = CompileCounterWithBackend(backend)

def fn(x, y):
def inner(x, y):
return torch.sin(x + y)

return control_flow.map(inner, x, y.size(0))

x = torch.randn(3, 1)
y = torch.randn(3, 1)
compiled_fn = torch.compile(fn, backend=backend)(x, y)
self.assertEqual(len(backend.graphs), 1)
graph = backend.graphs[0]
# Dynamic shapes produce a slightly different graph.
if check_dynamic_shape_capture():
return

# TODO(yidi): remove the getitem = l_x_.__getitem__(0) call. It's
# created accidently when we create sample inputs based on the 0-th slice
# before specualting the f in MapHigherOrder.
self.assertExpectedInline(
graph.code.strip(),
"""\
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
getitem = l_x_.__getitem__(0)
map_body_0 = self.map_body_0
map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
getitem_1 = map_impl[0]; map_impl = None
return (getitem_1,)""",
)

def test_cond_subgraph_name_is_valid(self):
backend = EagerAndRecordGraphs()
cnt = CompileCounterWithBackend(backend)
Expand Down
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,11 +765,12 @@ def call_function(
body_node = make_attr(tx, body_name)
p_args = (
body_node,
1, # right now we only supports num_mapped = 1
*(arg.as_proxy() for arg in args[1:]),
*(arg for arg in body_lifted_freevars.keys()),
)
return _call_function_and_unflatten_output(
tx, self.value, p_args, {}, body_r, body_spec
tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec
)


Expand Down