diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index c85848cd908b..eb40be4d5c5f 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -1186,6 +1186,40 @@ def inner(x, y):
         # get_item call created by the flatten/unflatten logic in HOP speculation.
         self.assertEqual(cnt.op_count, ifdynstaticdefault(3, 4))
 
+    def test_map_lowers_to_graph(self):
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        def fn(x, y):
+            def inner(x, y):
+                return torch.sin(x + y)
+
+            return control_flow.map(inner, x, y.size(0))
+
+        x = torch.randn(3, 1)
+        y = torch.randn(3, 1)
+        compiled_fn = torch.compile(fn, backend=backend)(x, y)
+        self.assertEqual(len(backend.graphs), 1)
+        graph = backend.graphs[0]
+        # Dynamic shapes produce a slightly different graph.
+        if check_dynamic_shape_capture():
+            return
+
+        # TODO(yidi): remove the getitem = l_x_.__getitem__(0) call. It's
+        # created accidentally when we create sample inputs based on the 0-th slice
+        # before speculating the f in MapHigherOrder.
+        self.assertExpectedInline(
+            graph.code.strip(),
+            """\
+def forward(self, L_x_ : torch.Tensor):
+    l_x_ = L_x_
+    getitem = l_x_.__getitem__(0)
+    map_body_0 = self.map_body_0
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+    getitem_1 = map_impl[0]; map_impl = None
+    return (getitem_1,)""",
+        )
+
     def test_cond_subgraph_name_is_valid(self):
         backend = EagerAndRecordGraphs()
         cnt = CompileCounterWithBackend(backend)
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 3e519e0437cd..cfcb9b935589 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -765,11 +765,12 @@ def call_function(
         body_node = make_attr(tx, body_name)
         p_args = (
             body_node,
+            1,  # right now we only support num_mapped = 1
             *(arg.as_proxy() for arg in args[1:]),
             *(arg for arg in body_lifted_freevars.keys()),
         )
         return _call_function_and_unflatten_output(
-            tx, self.value, p_args, {}, body_r, body_spec
+            tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec
         )