Enhanced generators with grad-mode decorators #49017
Conversation
Hi @ivannz! Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
💊 CI failures summary (Dr. CI), as of commit b077c4b: ci.pytorch.org: 1 failed.
Force-pushed from 42f0aee to 2dc287b
In fact, the `sys.exc_info` docs state that the traceback can be `None` only if no exception is currently being handled anywhere on the stack, which cannot be the case in the highlighted `except` clause.
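A quick illustration of this point (a standalone snippet, not code from the PR): inside an `except` clause an exception is by definition being handled, so `sys.exc_info()` cannot return a `None` traceback there.

```python
import sys

try:
    raise ValueError("boom")
except ValueError:
    # An exception is being handled on this frame, so the
    # traceback returned by sys.exc_info() cannot be None here.
    typ, val, tb = sys.exc_info()
    print(typ is ValueError, tb is not None)  # True True
```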
Force-pushed from 0dee7aa to d8ef729
Thanks for the PR and all the tests!
This looks mostly OK to me, except for the small details listed below.
```python
    with grad_mode_context:
        response = gen.send(request)
except StopIteration as e:
    return e.value
```
Can the `StopIteration` have a value here? The previous code was always returning `None` here.
`StopIteration` is a special exception that is guaranteed to have a value. Quoting from the docs:

> When a generator or coroutine function returns, a new StopIteration instance is raised, and the value returned by the function is used as the value parameter to the constructor of the exception.
Therefore, I think, for completeness we should propagate the `.value` as well.
Interesting. So the previous code was wrong to ignore that value? Do you have a Python code sample that would show how this value is propagated?
I am trying to think whether this PR is going to be BC-breaking for any important usage of the current code. Because functions wrapped in this decorator were already returning the right value on exit, right?
I would not call it "wrong", since I haven't come across any implementation in which the generator explicitly communicates with the caller through a `StopIteration` value, i.e. by `return something`. Nevertheless, the Python specification states that such behaviour is allowed and documented.
I completely agree that any modification of the generator behaviour by all means MUST be 100% backwards compatible, so extra care must be taken. However, after studying the docs and the relevant PEPs I came to the conclusion that this enhanced generator protocol is compatible with the common `for` expressions and the `next` function.
Force-pushed from d8ef729 to 42efb4e
@albanD The following relates to our discussion about the value of `StopIteration`. This example shows what returning from a generator does:

```python
def gen():
    for i in range(10):
        yield i
    # return with a value is automatically converted to StopIteration
    return 10
```

From the docs for `return`:

> In a generator function, the return statement indicates that the generator is done and will cause StopIteration to be raised. The returned value (if any) is used as an argument to construct StopIteration and becomes the StopIteration.value attribute.

The value is silently discarded when the generator is consumed by a `for` loop:

```python
for i in gen():
    print(i)
```

whereas driving the generator manually with `next` exposes it:

```python
it = gen()
try:
    while True:
        print(next(it))
except StopIteration as e:
    print(type(e), e, e.value)
```

For example, here is a generator without a `return`, whose `StopIteration` carries `None` as the value:

```python
def gen():
    for i in range(10):
        yield i

it = gen()
try:
    while True:
        print(next(it))
except StopIteration as e:
    print(type(e), e, e.value)
```
I admit, not merging this PR won't break much, though: it is possible to rewrite a generator's code in such a way as to apply the grad-mode context manager to the critical sections inside the generator. At the same time, it is conceivable that the following situation might arise (which is not related per se to this PR):

```python
# ctx.py
import torch

def gen():
    with torch.no_grad():
        yield 1  # some compute, e.g. torch.matmul(...)
    yield 2

with torch.enable_grad():
    it = gen()
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
```

outputs

```
True
1
False
2
True
```

This non-intuitive behaviour happens because the generator is paused inside the no-grad context: grads were disabled when the generator was first advanced, and they stay disabled until the generator is resumed and leaves the `with` block. The following code illustrates this (the same can be seen by tracing the example above):

```python
class ctx:
    def __enter__(self):
        print('>>> enter')

    def __exit__(self, typ, val, tb):
        print('<<< exit')

def gen():
    with ctx():
        yield 1
    yield 2

it = gen()
print(next(it))
print('do something')
print(next(it))
```

prints

```
>>> enter
1
do something
<<< exit
2
```

I guess, independently of this PR, it might be necessary to add a warning about this somewhat non-intuitive behaviour to the PyTorch docs related to grad-mode context managers in generators.
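A minimal sketch of the rewrite mentioned above (the names here are illustrative, with a plain tracking context manager standing in for a grad-mode guard): the surprise disappears if the generator enters and exits the context around each critical section, so that it never suspends while the context is active.

```python
events = []

class ctx:
    """Toy context manager standing in for torch.no_grad()."""
    def __enter__(self):
        events.append('enter')

    def __exit__(self, typ, val, tb):
        events.append('exit')

def gen():
    with ctx():
        x = 1  # critical section runs under the context
    yield x    # suspend only after the context has been exited
    with ctx():
        y = 2
    yield y

it = gen()
next(it)
# While the generator is suspended, the context is fully closed:
print(events)  # ['enter', 'exit']
```

The trade-off is that the generator's author must manually keep every `yield` outside the `with` blocks, which is exactly the kind of boilerplate the decorator in this PR aims to remove.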
But this PR would fix the behavior, right? I was asking about the BC-breaking change mostly to know how to document this PR in the release notes. Fixing bad behavior is definitely worth a BC-breaking change!
@albanD Unfortunately, this PR cannot address this correct, albeit non-intuitive, behaviour. All it does is allow one to wrap an enhanced generator function as a whole in a grad-mode context, letting two-way communication through and handling `StopIteration`. In general this kind of behaviour, as you have seen in the example above, is inherent to pausing a generator while a context manager is active inside it.

If I understand correctly, by adding a grad-mode decorator to a generator or an ordinary function, the user "enters" into a sort of "contract" with the library, expecting the grad mode to be properly set every time the execution flow is within the body of that function. This is precisely why the implementation in this PR wraps every interaction with the generator in a `with` clause.

edits: fixed typos and wrote a slightly more verbose description
From what I see, the fact that you wrap only the `send` (and not the `yield` back to the ultimate caller) means that the grad mode is modified only while the execution flow is inside the generator's body. For example:

```python
# ctx.py
import torch

def gen():
    print("inside 0 ", torch.is_grad_enabled())
    yield 1
    print("inside 1 ", torch.is_grad_enabled())
    yield 2
    print("inside 2 ", torch.is_grad_enabled())

def no_grad_wrapper(gen):
    def wrapped():
        g = gen()
        with torch.no_grad():
            resp = g.send(None)
        while True:
            # Simplified version of the wrapper in this PR that
            # wraps the send but not the yield
            req = yield resp
            with torch.no_grad():
                resp = g.send(req)
    return wrapped

print("default:")
with torch.enable_grad():
    it = gen()
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
print("")

print("wrapped:")
with torch.enable_grad():
    it = no_grad_wrapper(gen)()
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
```

Would output:

```
default:
True
inside 0  True
1
True
inside 1  True
2
True

wrapped:
True
inside 0  False
1
True
inside 1  False
2
True
```

But your tests already cover this case, right?
I think the code is good to merge.
A couple of comments capturing the discussions we had here would be great before we merge!
Force-pushed from 42efb4e to 7f0b7fc
@albanD I have updated the test suite to better verify that grad mode is correctly set inside a generator and that exceptions and special conditions are correctly propagated. The tests now show more clearly the difference between the old implementation and the one from this PR. The use-case you provide is tested by the existing test_set_grad_generator_functions. Other tests verify that the grad mode is correctly set when the execution flow is inside the body of the generator, including under special conditions.
Looks good. The inline comments look great!
Can you just fix the bad default arguments in the test? (that will fix the lint)
…g with proper grad-mode
…pagated into the wrapped generator
Force-pushed from 7f0b7fc to b077c4b
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Codecov Report
```
@@            Coverage Diff             @@
##           master   #49017      +/-   ##
==========================================
- Coverage   80.63%   80.63%   -0.01%
==========================================
  Files        1875     1875
  Lines      202714   202726      +12
==========================================
+ Hits       163453   163458       +5
- Misses      39261    39268       +7
```
Summary: This PR addresses the feature request outlined in pytorch#48713 for two-way communication with enhanced generators from [pep-342](https://www.python.org/dev/peps/pep-0342/). Briefly, the logic of the patch resembles `yield from` [pep-380](https://www.python.org/dev/peps/pep-0380/), which cannot be used, since the generator **must be interacted with from within the grad-mode context**, while yields from the decorator **must take place outside of the context**. Hence any interaction with the wrapped generator, be it via [.send](https://docs.python.org/3/reference/expressions.html?highlight=throw#generator.send), [.throw](https://docs.python.org/3/reference/expressions.html?highlight=throw#generator.throw), or even [.close](https://docs.python.org/3/reference/expressions.html?highlight=throw#generator.close), must be wrapped by a `with` clause. The patch is compatible with the `for i in gen: pass` and `next(gen)` use cases and allows two-way communication with the generator via `.send <-> yield` points.

### Logic

At lines [L37-L38](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L37-L38) we (the decorator) **start the wrapped generator** (coroutine) by issuing `None` into it (equivalently, we could use `next(gen)` here). Then we **dispatch responses of the generator** to our ultimate caller and **relay the latter's requests** into the generator in the loop on lines [L39-L52](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L39-L52). We yield the most recent response on [L40-L41](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L40-L41), at which point we become **paused**, waiting for the ultimate caller's next interaction with us.

If the caller **sends us a request**, we become unpaused, move to [L51-L52](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L51-L52) and **forward it into the generator**, at which point we pause, waiting for its response. The response might be a value, an exception, or a `StopIteration`. In the case of an exception from the generator, we let it **bubble up** from the immediately surrounding [except clause](https://docs.python.org/3/reference/compound_stmts.html#the-try-statement) to the ultimate caller through the [outer try-except](https://github.com/ivannz/pytorch/blob/2dc287bba87fa6f05c49446c0239ffdcdb1e896e/torch/autograd/grad_mode.py#L36-L54). In the case of a `StopIteration`, we **take its payload and propagate it** to the caller via [return](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L54). In the case of a value, the flow and the loop continue.

The caller **throwing an exception at us** is handled much like a proper request, except that the exception plays the role of the request. In this case we **forward it into the generator** on lines [L47-L49](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L47-L49) and await its response. We explicitly **advance** the traceback one frame up, in order to indicate the **source of the exception within the generator**.

Finally, `GeneratorExit` is handled on lines [L42-L45](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L42-L45) and closes the generator.

Updates: clarified exception propagation

Pull Request resolved: pytorch#49017
Reviewed By: izdeby
Differential Revision: D25567796
Pulled By: albanD
fbshipit-source-id: 801577cccfcb2b5e13a08e77faf407881343b7b0