Enhanced generators with grad-mode decorators #49017

Closed
wants to merge 7 commits
181 changes: 181 additions & 0 deletions test/test_autograd.py
@@ -1161,6 +1161,187 @@ def no_grad_context_manager_recursive(depth):
        enable_grad_context_manager_recursive(10)
        self.assertFalse(torch.is_grad_enabled())

    def test_set_grad_coroutines(self):
        @torch.no_grad()
        def coro_no_grad(n=10):
            self.assertFalse(torch.is_grad_enabled())
            for i in range(n):
                self.assertFalse(torch.is_grad_enabled())
                r = yield i
                self.assertFalse(torch.is_grad_enabled())
                self.assertEqual(i, r)
            self.assertFalse(torch.is_grad_enabled())

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            self.assertTrue(torch.is_grad_enabled())
            for i in range(n):
                self.assertTrue(torch.is_grad_enabled())
                r = yield i
                self.assertTrue(torch.is_grad_enabled())
                self.assertEqual(i, r)
            self.assertTrue(torch.is_grad_enabled())

        with torch.enable_grad():
            self.assertTrue(torch.is_grad_enabled())
            coro, r = coro_no_grad(), None
            try:
                while True:
                    self.assertTrue(torch.is_grad_enabled())
                    r = coro.send(r)
                    self.assertTrue(torch.is_grad_enabled())

            except StopIteration:
                pass

        with torch.no_grad():
            self.assertFalse(torch.is_grad_enabled())
            coro, r = coro_enable_grad(), None
            try:
                while True:
                    self.assertFalse(torch.is_grad_enabled())
                    r = coro.send(r)
                    self.assertFalse(torch.is_grad_enabled())

            except StopIteration:
                pass

    def test_set_grad_coroutines_benign_exceptions(self):
        class RecoverableException(Exception):
            pass

        @torch.no_grad()
        def coro_no_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except RecoverableException:
                    self.assertFalse(torch.is_grad_enabled())
                    has_raised = True

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except RecoverableException:
                    self.assertTrue(torch.is_grad_enabled())
                    has_raised = True

        with torch.enable_grad():
            coro = coro_no_grad()
            assert 0 == next(coro)
            try:
                while True:
                    r = coro.throw(RecoverableException)
                    self.assertLess(r, 0)

            except StopIteration:
                pass

        with torch.no_grad():
            coro = coro_enable_grad()
            assert 0 == next(coro)
            try:
                while True:
                    r = coro.throw(RecoverableException)
                    self.assertLess(r, 0)

            except StopIteration:
                pass

    def test_set_grad_coroutines_critical_exceptions(self):
        class UnrecoverableException(Exception):
            pass

        class SecondaryException(Exception):
            pass

        @torch.no_grad()
        def coro_no_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except UnrecoverableException:
                    self.assertFalse(torch.is_grad_enabled())
                    raise SecondaryException

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except UnrecoverableException:
                    self.assertTrue(torch.is_grad_enabled())
                    raise SecondaryException

        with torch.enable_grad():
            coro = coro_no_grad()
            assert 0 == next(coro)
            with self.assertRaises(SecondaryException):
                coro.throw(UnrecoverableException)

        with torch.no_grad():
            coro = coro_enable_grad()
            assert 0 == next(coro)
            with self.assertRaises(SecondaryException):
                coro.throw(UnrecoverableException)

    def test_set_grad_coroutines_exit(self):
        @torch.no_grad()
        def coro_no_grad(state):
            for i in range(10):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield i

                except GeneratorExit:
                    self.assertFalse(torch.is_grad_enabled())
                    state.add('GeneratorExit')
                    raise

        @torch.enable_grad()
        def coro_enable_grad(state):
            for i in range(10):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield i

                except GeneratorExit:
                    self.assertTrue(torch.is_grad_enabled())
                    state.add('GeneratorExit')
                    raise

        state = set()
        with torch.enable_grad():
            coro = coro_no_grad(state)
            for i in range(5):
                next(coro)

            coro.close()
            self.assertTrue('GeneratorExit' in state)

        state = set()
        with torch.no_grad():
            coro = coro_enable_grad(state)
            for i in range(5):
                next(coro)

            coro.close()
            self.assertTrue('GeneratorExit' in state)

    def test_no_grad_python_function(self):
        """Python Functions should respect grad mode."""
        x = torch.ones(5, 5, requires_grad=True)
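
For a compact, stand-alone view of the round-trip behaviour these tests check, here is an illustrative sketch; the generator name and values are invented, and the behaviour assumes the wrapper from the grad_mode.py diff below:

import torch


@torch.enable_grad()
def echo(n=3):
    for i in range(n):
        # On every resume the decorator's grad mode is re-established here ...
        assert torch.is_grad_enabled()
        received = yield i
        assert received == i


with torch.no_grad():
    coro, reply = echo(), None
    try:
        while True:
            # ... while the caller's own grad mode is restored around each send().
            assert not torch.is_grad_enabled()
            reply = coro.send(reply)
    except StopIteration:
        pass

This mirrors what test_set_grad_coroutines does above, with the grad-mode assertions kept in both directions of the exchange.
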
48 changes: 41 additions & 7 deletions torch/autograd/grad_mode.py
@@ -1,3 +1,4 @@
import sys
import torch
import functools
import inspect
@@ -31,13 +32,46 @@ def _wrap_generator(self, func):
        @functools.wraps(func)
        def generator_context(*args, **kwargs):
            gen = func(*args, **kwargs)
            while True:
                try:
                    with self.__class__():
                        x = next(gen)
                    yield x
                except StopIteration:
                    break

            # Generators are suspended and unsuspended at `yield`, hence we
            # make sure the grad mode is properly set every time the execution
            # flow returns into the wrapped generator and restored when it
            # returns through our `yield` to our caller (see PR #49017).
            cls = type(self)
            try:
                # Issuing `None` to a generator fires it up
                with cls():
                    response = gen.send(None)

                while True:
                    try:
                        # Forward the response to our caller and get its next request
                        request = yield response

                    except GeneratorExit:
                        # Inform the still active generator about its imminent closure
                        with cls():
                            gen.close()
                        raise

                    except BaseException:
                        # Propagate the exception thrown at us by the caller
                        with cls():
                            response = gen.throw(*sys.exc_info())

                    else:
                        # Pass the last request to the generator and get its response
                        with cls():
                            response = gen.send(request)

            # We let the exceptions raised above by the generator's `.throw` or
            # `.send` methods bubble up to our caller, except for StopIteration
            except StopIteration as e:
                # The generator informed us that it is done: take whatever its
                # returned value (if any) was and indicate that we're done too
                # by returning it (see docs for python's return-statement).
                return e.value

Collaborator:
Can the StopIteration have a value here? The previous code was always returning None at this point.

Contributor Author:
StopIteration is a special exception that is guaranteed to have a value. Quoting from the docs:

    When a generator or coroutine function returns, a new StopIteration instance is raised, and the value returned by the function is used as the value parameter to the constructor of the exception.

Therefore, I think that for completeness we should propagate the .value as well.

Collaborator:
Interesting. So the previous code was wrong to ignore that value? Do you have a Python code sample that shows how this value is propagated?

I am also trying to work out whether this PR is going to be BC-breaking for any important usage of the current code, because functions wrapped in this decorator were already returning the right value on exit, right?

Contributor Author (ivannz), Dec 14, 2020:
I would not call it "wrong", since I have not come across any implementation in which the generator explicitly communicates with its caller through a StopIteration value, i.e. by returning something. Nevertheless, the Python specification states that such behaviour is allowed and documents it.

I completely agree that any modification of the generator behaviour MUST by all means be 100% backwards compatible, so extra care must be taken. However, after studying the docs and the relevant PEPs, I came to the conclusion that this enhanced generator protocol is compatible with the common for loops and the next function.
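
For a concrete picture of where that value surfaces, here is an illustrative sample (not part of the PR; the generator below is made up, and the noted results assume the enhanced wrapper from this diff):

import torch


@torch.no_grad()
def produce(n=3):
    for i in range(n):
        yield i
    return 'done'  # stored as StopIteration.value when the generator finishes


# A plain for-loop (or bare next()) never inspects that value, so existing
# callers keep working exactly as before.
for _ in produce():
    pass

# Driving the generator manually exposes the value on the raised StopIteration;
# with the enhanced wrapper this prints 'done', whereas the previous wrapper
# always surfaced None at this point.
gen = produce()
try:
    while True:
        next(gen)
except StopIteration as stop:
    print(stop.value)


# `yield from` receives it as the value of the expression.
def relay():
    result = yield from produce()
    return result

Since the for-loop and next() paths behave identically before and after the change, only callers that explicitly read StopIteration.value (or capture the result of yield from) should notice the difference.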


        return generator_context

    def __enter__(self) -> None: