aten::set_grad_enabled should not push as it does not return a value (#45559)

Summary:
Fixes #45558

This assertion failure is caused by the incorrect implementation of ``aten::set_grad_enabled`` in [torch/csrc/jit/runtime/register_special_ops.cpp](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/register_special_ops.cpp#L436). The current implementation is:

```cpp
Operator(
    "aten::set_grad_enabled(bool val) -> ()",
    [](Stack* stack) {
      torch::GradMode::set_enabled(pop(stack).toBool());
      push(stack, IValue());
    },
    aliasAnalysisConservative()),
```

which pushes a ``None`` onto the evaluation stack after calling ``set_enabled``. This is incorrect: according to the signature, this function does not return a value. I guess the original author might have been confused by the behavior of Python, which pushes a ``None`` onto the evaluation stack when a function definition does not end with a return statement carrying an explicit value.
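
For reference, a minimal illustration of that Python behavior (this snippet is purely illustrative and not part of the change):

```python
def f():
    pass  # no return statement with an explicit value

# CPython implicitly pushes None onto its evaluation stack and returns it
print(f() is None)  # True
```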

If ``aten::set_grad_enabled`` pushes a ``None`` onto the evaluation stack, the stack accumulates an extra ``None`` on every call. In our case, ``with torch.no_grad():`` causes ``aten::set_grad_enabled`` to be called twice (once on entering the block and once on exiting it), so when the ``forward`` method finishes, the evaluation stack is ``[None, None, Tensor]``. But the return statement of ``GraphFunction::operator()`` in [torch/csrc/jit/api/function_impl.cpp](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/function_impl.cpp#L51) is ``return stack.front();``, which tries to extract a tensor out of a ``None`` and thus causes the assertion failure.
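
A minimal sketch of the kind of script that exercises this path (the module, input shape, and values are illustrative, not copied from the linked issue):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # Entering and exiting the with-block each trigger aten::set_grad_enabled,
        # so before this fix the interpreter stack ended up as [None, None, Tensor].
        with torch.no_grad():
            y = x + 1
        return y

m = torch.jit.script(M())
# Before the fix, this call hit the assertion failure described above,
# since stack.front() held a None instead of the returned Tensor.
print(m(torch.ones(2)))
```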

The solution is simple: just remove the push from the implementation of ``aten::set_grad_enabled``.

Pull Request resolved: #45559

Reviewed By: albanD

Differential Revision: D24142153

Pulled By: SplitInfinity

fbshipit-source-id: 75aad0e38bd912a437f7e1a1ee89ab4445e35b5d
huaidong.xiong authored and facebook-github-bot committed Oct 8, 2020
1 parent ddcacc7 commit e3112e3
Showing 1 changed file with 0 additions and 1 deletion.
torch/csrc/jit/runtime/register_special_ops.cpp

```diff
@@ -433,7 +433,6 @@ RegisterOperators reg({
     "aten::set_grad_enabled(bool val) -> ()",
     [](Stack* stack) {
       torch::GradMode::set_enabled(pop(stack).toBool());
-      push(stack, IValue());
     },
     aliasAnalysisConservative()),
 });
```
