24 changes: 20 additions & 4 deletions test/export/test_export.py
@@ -472,9 +472,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
        self.assertEqual(gm(*args), m(*args))

    # Traced graph contains a WrapWithSetGradEnabled hop but
    # dynamo doesn't support the hop yet so the test fails in strict_mode when re-tracing.
    @testing.expectedFailureRetraceability
    def test_setgrad_lifted_tensor(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
@@ -1488,7 +1485,6 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
return (add,)""",
)

    @testing.expectedFailureRetraceability  # Unexpected type in sourceless builder torch._higher_order_ops.wrap.WrapWithSetGradEnabled
    def test_set_grad_empty(self):
        class M(torch.nn.Module):
            def forward(self, x):
@@ -5830,6 +5826,26 @@ def _test(m, non_persistent_buffer):
        _test(MyModule(), "foo")
        _test(MyOuterModule(), "inner.foo")

    def test_export_with_set_grad_enabled(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                with torch.no_grad():
                    return self.linear(x)

        model = Model()
        ep = export(model, (torch.randn(4, 4),), {})
        # _export_for_training uses pre_dispatch=False,
        # so the set_grad calls are not replaced with a hop.
        if not is_training_ir_test(self._testMethodName):
            self.assertIn(
                "torch.ops.higher_order.wrap_with_set_grad_enabled",
                ep.graph_module.code,
            )

    def test_export_as_backend(self):
        def f(x, y):
            return x + y
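For context on what the new test asserts: with the default pre-dispatch export path, the no_grad region in Model.forward is captured as a wrap_with_set_grad_enabled higher-order op in the exported graph. The sketch below is not part of the diff and the node/submodule names are illustrative; it just reproduces the scenario the test checks.

import torch
from torch.export import export

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        with torch.no_grad():
            return self.linear(x)

ep = export(Model(), (torch.randn(4, 4),))
print(ep.graph_module.code)
# With pre-dispatch export, the printed code is expected to contain a call along the lines of
#   torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_..., ...)
# wrapping the linear op; the exact argument and submodule names here are illustrative.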
75 changes: 75 additions & 0 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -597,6 +597,8 @@ def make(value, source=None, **kwargs):
            return AssociativeScanHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "call_torchbind":
            return CallTorchbindHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "wrap_with_set_grad_enabled":
            return WrapWithSetGradEnabledHigherOrderVariable(value, source, **kwargs)
        else:
            unimplemented(f"HigherOrderOperator {value.__name__}")

@@ -1356,6 +1358,79 @@ def call_function(
)


class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable):
    """
    This hop is not exposed to users but is inserted into the graph
    after export as a post-processing step.
    """

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

        if kwargs:
            unimplemented(
                f"wrap_with_set_grad_enabled: Got unexpected kwargs: {list(kwargs.keys())}"
            )

        # The hop is called as (grad_enabled, wrapped_fn, *operands).
        grad_enabled, fn_var, *rest_args = args

        if not isinstance(grad_enabled, ConstantVariable):
            unimplemented("grad_enabled must be a constant")

        _check_supported_callable_arg(tx, fn_var, "enable_grad_fn")

        # Trace the wrapped callable into a subgraph with the requested grad mode active.
        with torch.set_grad_enabled(grad_enabled.as_python_constant()):
            (
                (body_r, treespec),
                body_graph,
                body_lifted_freevars,
            ) = speculate_subgraph(
                tx,
                fn_var,
                [*rest_args],
                {},
                "torch.ops.higher_order.wrap_with_set_grad_enabled",
                source_target=self.value,
                set_subgraph_inputs="manual",
                should_flatten_outputs=True,
            )

        if len(body_lifted_freevars) > 0:
            unimplemented(
                f"wrap_with_set_grad_enabled: Got unexpected freevars {body_lifted_freevars}"
            )

        # Install the traced body as a submodule and emit a call to the hop in the output graph.
        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
        body_name = add_subgraph(
            tx,
            "wrap_body",
            body_gmod,
        )

        body_node = make_attr(tx, body_name)

        proxy_args = tuple(
            [
                grad_enabled.as_python_constant(),
                body_node,
            ]
            + [operand.as_proxy() for operand in rest_args]
        )
        example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )
        return _call_function_and_unflatten_output(
            tx, self.value, proxy_args, {}, example_value, treespec
        )


class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self,
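The dynamo handler added above is what makes graphs containing this hop re-traceable. A minimal sketch of the scenario, assuming strict export (which traces through dynamo) and reusing the toy Model from the new test; this snippet is illustrative and not part of the diff. Re-exporting the unlifted exported module sends wrap_with_set_grad_enabled through dynamo, which now dispatches it to WrapWithSetGradEnabledHigherOrderVariable instead of falling through to the unimplemented() branch.

import torch
from torch.export import export

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        with torch.no_grad():
            return self.linear(x)

ep = export(Model(), (torch.randn(4, 4),))
# Re-trace the exported module (what the retraceability test flavor does);
# the graph being traced already contains the hop, so dynamo must handle it.
reexported = export(ep.module(), (torch.randn(4, 4),))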