
Conversation

@jvmncs (Contributor) commented May 21, 2018

Extends the context manager classes torch.no_grad, torch.enable_grad, and torch.set_grad_enabled so that they also function as decorators, letting users wrap functions that will never require a call to .backward() downstream. I've updated the docs to reflect this change, and I've also added tests for the new functionality under each mode's respective test in test_autograd.py.
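For illustration, a minimal sketch of the usage this enables (the function and tensor here are my own example, not taken from the patch):

    import torch

    @torch.no_grad()  # instantiate the context manager, then use it as a decorator
    def doubler(x):
        return x * 2

    x = torch.ones(3, requires_grad=True)
    y = doubler(x)
    print(y.requires_grad)  # False: grad tracking was disabled inside doubler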

I also didn't find a unit test specifically for torch.enable_grad; I'm assuming that's intended, unless my Ctrl+F missed it.

@@ -29,13 +38,23 @@ def __exit__(self, *args):
         torch.set_grad_enabled(self.prev)
         return False

+    def __call__(self, func):
+        def decorate_no_grad(*args, **kwargs):
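The hunk is truncated here; a plausible completion of the wrapper, consistent with the later commit note "switch to using context manager in wrapper" (the functools.wraps line and the exact body are my assumption, not the verbatim patch):

    def __call__(self, func):
        @functools.wraps(func)  # assumes functools is imported at the top of grad_mode.py
        def decorate_no_grad(*args, **kwargs):
            # Re-enter the context manager around every call, so grad mode
            # is only toggled for the duration of the wrapped function.
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad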


@ssnl (Collaborator) left a comment

LGTM if test passes

(Diff context from one of the new decorator tests in test_autograd.py:)

        return x * 2

    y = doubler_with(x)
    self.assertTrue(y.requires_grad)
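For context, a self-contained sketch of the pattern that assertion exercises (the setup around doubler_with is my reconstruction, not the verbatim test):

    import torch

    x = torch.randn(5, 5, requires_grad=True)

    @torch.enable_grad()
    def doubler_with(x):
        return x * 2

    with torch.no_grad():
        y = doubler_with(x)

    assert y.requires_grad  # enable_grad re-enables tracking inside the wrapper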


@ssnl (Collaborator) commented May 22, 2018

@pytorchbot retest this please

@ssnl (Collaborator) commented May 22, 2018

@pytorchbot retest this please

@ssnl (Collaborator) commented May 22, 2018

@pytorchbot retest this please

@ssnl (Collaborator) commented May 22, 2018

@pytorchbot test this please

@ssnl closed this May 22, 2018
@ssnl reopened this May 22, 2018
@ssnl (Collaborator) commented May 22, 2018

test failure looks legit:

19:56:37 FAIL: test_set_grad_enabled (test_autograd.TestAutograd)
19:56:37 ----------------------------------------------------------------------
19:56:37 Traceback (most recent call last):
19:56:37   File "test_autograd.py", line 1812, in test_set_grad_enabled
19:56:37     self.assertTrue(y.requires_grad)
19:56:37 AssertionError: False is not true

@jvmncs (Contributor, Author) commented May 23, 2018

Ah, indeed. I'm going to have to change up the set_grad_enabled class. Nice catch.

@jvmncs (Contributor, Author) commented May 23, 2018

So, a question: in order to maintain the current behavior, e.g. using torch.set_grad_enabled(False) imperatively, I need to save the input to set_grad_enabled as an attribute, then change __enter__ to set the grad mode to that attribute's value:

def __init__(self, mode):
    # Imperative path: constructing the object flips grad mode immediately,
    # which is what keeps torch.set_grad_enabled(False) working as before.
    self.prev = torch.is_grad_enabled()
    torch._C.set_grad_enabled(mode)
    self.mode = mode

def __enter__(self):
    # Context-manager/decorator path: re-apply the saved mode on entry.
    torch.set_grad_enabled(self.mode)

def __exit__(self, *args):
    torch.set_grad_enabled(self.prev)
    return False

However, if I do that, then any time the class is instantiated, including when it wraps a function, the underlying grad mode will be changed as a side effect. That seems like unwanted behavior, so maybe it's best not to use this one as a decorator? Open to other suggestions.
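To spell out the conflict (a minimal sketch; torch.set_grad_enabled and torch.is_grad_enabled are the real APIs, the decorated function is hypothetical):

    import torch

    # Imperative use: grad mode flips the moment the object is constructed.
    torch.set_grad_enabled(False)
    assert not torch.is_grad_enabled()
    torch.set_grad_enabled(True)

    # If the same class were also a decorator factory, the line below would
    # execute once at definition time and flip grad mode globally right
    # then, not just inside f -- exactly the side effect described above:
    #
    # @torch.set_grad_enabled(False)
    # def f(x):
    #     return x * 2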

@apaszke (Contributor) commented May 23, 2018

I think it's OK to forbid using set_grad_enabled as a decorator: it's generally useful when run-time values are changing, but function definition usually happens only once, at the beginning of the program.

@ssnl merged commit 4352eab into pytorch:master May 23, 2018
petrex pushed a commit to petrex/pytorch that referenced this pull request May 31, 2018
* origin:
  [Caffe2] Enabling AMD GPU Backend for Caffe2 (pytorch#7566)
  Call grad_mode.py context managers as decorators (pytorch#7737)
  catch CPU tensors in checkSameGPU (fixes pytorch#7689) (pytorch#7767)
  Mark stack as non-executable in NNPACK (pytorch#7752)
  small fixes in fusion_compiler (pytorch#7776)
  Run clang-format on c10d (pytorch#7791)
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018
* call grad_mode.py context managers as decorators

* flake fixes

* switch to using context manager in wrapper

* fix set_grad_enabled test

* removed dumb github UI whitespace

* revert set_grad_enabled to normal, update tests