
Enhanced generators with grad-mode decorators #49017

Closed
wants to merge 7 commits

Conversation

ivannz
Contributor

@ivannz ivannz commented Dec 8, 2020

This PR addresses the feature request outlined in #48713 for two-way communication with enhanced generators from pep-342.

Briefly, the logic of the patch resembles yield from pep-380, which cannot be used, since the generator must be interacted with from within the grad-mode context, while yields from the decorator must take place outside of the context. Hence any interaction with the wrapped generator, be it via .send, .throw, and even .close must be wrapped by a with clause. The patch is compatible with for i in gen: pass and next(gen) use cases and allows two-way communication with the generator via .send <-> yield points.

Logic

At lines L37-L38 we (the decorator) start the wrapped generator (coroutine) by issuing None into it (equivalently, we could use next(gen) here). Then we dispatch the generator's responses to our ultimate caller and relay the caller's requests back into the generator in the loop on lines L39-L52.

We yield the most recent response on L40-L41, at which point we become paused, waiting for the ultimate caller's next interaction with us. If the caller sends us a request, then we become unpaused, move to L51-L52 and forward it into the generator, at which point we pause, waiting for its response. The response might be a value, an exception or a StopIteration. In the case of an exception from the generator, we let it bubble up from the immediately surrounding except clause to the ultimate caller through the outer try-except. In the case of a StopIteration, we take its payload and propagate it to the caller via return. In the case of a value, the flow and the loop continue.

The caller throwing an exception at us is handled much like a proper request, except for the exception playing the role of the request. In this case we forward it into the generator on lines L47-L49 and await its response. We explicitly advance the traceback one frame up, in order to indicate the source of the exception within the generator.

Finally, GeneratorExit is handled on lines L42-L45, which closes the generator.
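The control flow described above can be sketched as follows. This is a simplified, hypothetical re-implementation for illustration only, not the actual code in torch/autograd/grad_mode.py: the Flag class and ctx_generator_decorator names are stand-ins (Flag plays the role of the grad-mode context manager), and details such as the traceback adjustment are omitted.

```python
import functools

class Flag:
    """Hypothetical stand-in for a grad-mode context manager."""
    def __init__(self):
        self.active = False
    def __enter__(self):
        self.active = True
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.active = False

def ctx_generator_decorator(ctx):
    """Run every step of the wrapped generator inside `ctx`, but yield outside it."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            gen = func(*args, **kwargs)
            try:
                # Start the generator inside the context (equivalent to next(gen)).
                with ctx:
                    response = gen.send(None)
                while True:
                    try:
                        # Pause OUTSIDE the context until the caller interacts with us.
                        request = yield response
                    except GeneratorExit:
                        # The caller is closing us: close the generator in-context.
                        with ctx:
                            gen.close()
                        raise
                    except BaseException as exc:
                        # The caller threw at us: forward the exception in-context.
                        with ctx:
                            response = gen.throw(exc)
                    else:
                        # Ordinary request: relay it to the generator in-context.
                        with ctx:
                            response = gen.send(request)
            except StopIteration as e:
                # Propagate the generator's return value (its payload) to the caller.
                return e.value
        return wrapper
    return decorator
```

Note that every interaction with the inner generator (send, throw, close) is wrapped in a with clause, while the yield itself always happens after the context has exited.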

Updates: clarified exception propagation

@ivannz ivannz requested a review from albanD as a code owner December 8, 2020 14:19
@facebook-github-bot
Contributor

Hi @ivannz!

Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@dr-ci

dr-ci bot commented Dec 8, 2020

💊 CI failures summary and remediations

As of commit b077c4b (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed



@ivannz ivannz force-pushed the enh-grad-mode-deco branch 2 times, most recently from 42f0aee to 2dc287b Compare December 8, 2020 15:46
@zhangguanheng66 zhangguanheng66 added the labels enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Dec 8, 2020
Contributor Author

@ivannz ivannz left a comment


In fact, the sys.exc_info docs state that the traceback can be None only if no exception is currently being handled anywhere on the stack, which cannot be the case in the highlighted except clause.
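As a quick standalone illustration of that guarantee (a sketch, not code from this PR): inside an except clause an exception is necessarily being handled, so sys.exc_info() returns a full (type, value, traceback) triple there.

```python
import sys

def handler_demo():
    try:
        raise ValueError("boom")
    except ValueError:
        # An exception is actively being handled here, so sys.exc_info()
        # cannot return (None, None, None) in this clause.
        return sys.exc_info()

exc_type, exc_value, tb = handler_demo()
print(exc_type is ValueError, tb is not None)  # prints: True True
```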

torch/autograd/grad_mode.py (outdated)
@ivannz ivannz force-pushed the enh-grad-mode-deco branch 2 times, most recently from 0dee7aa to d8ef729 Compare December 13, 2020 22:56
Collaborator

@albanD albanD left a comment


Thanks for the PR and all the tests!

This looks mostly OK to me, except for the small details listed below.

torch/autograd/grad_mode.py (outdated)
    with grad_mode_context:
        response = gen.send(request)
except StopIteration as e:
    return e.value
Collaborator


Can the StopIteration have a value here? The previous code was always returning None here.

Contributor Author


StopIteration is a special exception that is guaranteed to have a value. Quoting from the docs

When a generator or coroutine function returns, a new StopIteration instance is raised, and the value returned by the function is used as the value parameter to the constructor of the exception.

Therefore, I think, for completeness we should propagate the .value as well.

Collaborator


Interesting. So the previous code was wrong to ignore that value? Do you have a Python code sample that shows how this value is propagated?

I am trying to think whether this PR is going to be BC-breaking for any important usage of the current code. Functions wrapped in this decorator were already returning the right value on exit, right?

Contributor Author

@ivannz ivannz Dec 14, 2020


I would not call it "wrong", since I haven't come across any implementation in which the generator explicitly communicates with the caller through a StopIteration value, i.e. by returning something. Nevertheless, the Python specification states that such behaviour is allowed and documented.

I completely agree that any modification of the generator behaviour absolutely MUST be 100% backwards compatible, so extra care must be taken. However, after studying the docs and the relevant PEPs I came to the conclusion that this enhanced generator protocol is compatible with the common for expressions and the next function.

torch/autograd/grad_mode.py (outdated)
@ivannz
Contributor Author

ivannz commented Dec 14, 2020

@albanD The following relates to our discussion about the value of the StopIteration exception.

The following example shows what returning from a generator does:

def gen():
    for i in range(10):
        yield i
    
    # return with a value is automatically converted to StopIteration
    return 10

from the docs

When a generator or coroutine function returns, a new StopIteration instance is raised, and the value returned by the function is used as the value parameter to the constructor of the exception.

And from the docs for return

In a generator function, the return statement indicates that the generator is done and will cause StopIteration to be raised. The returned value (if any) is used as an argument to construct StopIteration and becomes the StopIteration.value attribute.

A for loop ignores the payload of StopIteration:

for i in gen():
    print(i)

whereas next raises StopIteration with the payload:

it = gen()
try:
    while True:
        print(next(it))

except StopIteration as e:
    print(type(e), e, e.value)

For example, here is a generator without a return statement (essentially an implicit return, or return None):

def gen():
    for i in range(10):
        yield i

it = gen()
try:
    while True:
        print(next(it))

except StopIteration as e:
    print(type(e), e, e.value)

@ivannz
Contributor Author

ivannz commented Dec 14, 2020

I admit that not merging this PR would not break much: it is possible to rewrite a generator's code in such a way as to apply the grad-mode context manager to the critical sections inside the generator.

At the same time, it is conceivable that the following situation might arise (which is not related per se to this PR):

# ctx.py
import torch

def gen():
    with torch.no_grad():
        yield 1  # some compute, e.g. torch.matmul(...)
    yield 2

with torch.enable_grad():
    it = gen()
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())

outputs

True
1
False
2
True

This non-intuitive behaviour happens because the generator is paused inside the no-grad context: grads were disabled in __enter__, but have not been re-enabled when control is given back to next(it), since no __exit__ has been called. The disabled grad state persists until the succeeding next(it).

The following code illustrates this (the same can be seen by tracing the example above with python -m pdb ctx.py):

class ctx:
    def __enter__(self):
        print('>>> enter')
    def __exit__(self, typ, val, tb):
        print('<<< exit')

def gen():
    with ctx():
        yield 1
    yield 2

it = gen()
print(next(it))
print('do something')
print(next(it))

prints

>>> enter
1
do something
<<< exit
2

I guess, independently of this PR, it might be worth adding a warning about this somewhat non-intuitive behaviour to the PyTorch docs related to grad-mode context managers in generators.

@albanD
Collaborator

albanD commented Dec 14, 2020

At the same time, it is conceivable that the following situation might arise (which is not related per se to this PR):

But this PR would fix the behavior, right?

I was asking about BC-breaking change mostly to know how to document this PR in the release notes. Fixing bad behavior is definitely worth doing BC-breaking changes!

@ivannz
Contributor Author

ivannz commented Dec 15, 2020

@albanD Unfortunately, this PR cannot change this behaviour, which is correct albeit non-intuitive. All it does is allow one to wrap an enhanced generator function as a whole in a grad-mode context, letting two-way communication through and handling StopIteration payloads properly.

In general, this kind of behaviour, as you have seen in the with-yield example above, is core to the yield expression in Python: it pauses the generator and gives control back to the caller, which runs until its next call to next, .send or .throw, upon which the generator is un-paused. It may be that the generator is never resumed and is subsequently correctly garbage collected (by being issued a GeneratorExit, which is handled inside the context on L42-L45). Quoting from the docs on the yield expression (emphasis and edits mine):

... execution proceeds to the first yield expression, where it is [paused], returning ... to the generator’s caller. By [paused], we mean that all local state is retained, including the current bindings of local variables, the instruction pointer, the internal evaluation stack, and the state of any exception handling. When the execution is resumed by calling one of the generator’s methods, the [generator] function can proceed exactly as if the yield expression were just another external call.

If I understand correctly, torch._C._set_grad_enabled has a persistent thread-local, but not stack-frame-local, effect. At the same time, yield switches between stack frames within the same thread, thereby resembling cooperative multitasking without preemption. This means that __exit__ would not be called in the with-yield block, hence the call to torch._C._set_grad_enabled(mode) made upon __enter__ would still be in effect while the generator is suspended, i.e. between yield and the subsequent next.

By adding a grad-mode decorator to a generator or an ordinary function, the user "enters" into a sort of "contract" with the library, expecting the grad mode to be properly set every time the execution flow is within the body of that function. This is precisely why the implementation in this PR wraps each call to the wrapped generator's methods in a with clause, while actively avoiding putting a yield inside a with (otherwise we could have used yield from instead of replicating its logic).
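To illustrate why yield from inside a with block would violate that contract, here is a hypothetical minimal sketch with a flag-recording context manager (unrelated to torch internals): the context remains entered the entire time the wrapper is suspended, leaking into the caller's code.

```python
class Flag:
    """Hypothetical context manager that records whether it is entered."""
    def __init__(self):
        self.active = False
    def __enter__(self):
        self.active = True
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.active = False

flag = Flag()

def inner():
    yield 1
    yield 2

def naive_wrapper():
    # Tempting but wrong: the context stays entered the whole time the
    # wrapper is suspended at each yield inside `yield from`.
    with flag:
        yield from inner()

it = naive_wrapper()
next(it)
print(flag.active)  # prints: True -- the context leaks to the caller while suspended
```

This is exactly the leakage the decorator in this PR avoids by never yielding while in-context.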

edits: fixed typos and wrote slightly more verbose description

@albanD
Collaborator

albanD commented Dec 15, 2020

From what I see, the fact that you wrap only the gen.send and not the yield does solve the issue I was thinking about above:

# ctx.py
import torch

def gen():
    print("inside 0 ", torch.is_grad_enabled())
    yield 1
    print("inside 1 ", torch.is_grad_enabled())
    yield 2
    print("inside 2 ", torch.is_grad_enabled())

def no_grad_wrapper(gen):
    def wrapped():
        g = gen()
        with torch.no_grad():
            resp = g.send(None)

        while True:
            # Simplified version of the wrapper in this PR that
            # wraps the send but not the yield
            req = yield resp

            with torch.no_grad():
                resp = g.send(req)

    return wrapped

print("default:")
with torch.enable_grad():
    it = gen()
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())

print("")
print("wrapped:")
with torch.enable_grad():
    it = no_grad_wrapper(gen)()
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())
    print(next(it))
    print(torch.is_grad_enabled())

Would output:

default:
True
inside 0  True
1
True
inside 1  True
2
True

wrapped:
True
inside 0  False
1
True
inside 1  False
2
True

But your tests already cover this case, right?
And if the user does something bad in their own generator, it is expected, as you said, that the wrapper will not fix it.

Collaborator

@albanD albanD left a comment


I think the code is good to merge.
A couple of comments capturing the discussions we had here would be great before we merge!

torch/autograd/grad_mode.py
@ivannz
Contributor Author

ivannz commented Dec 15, 2020

@albanD I have updated the test suite to better verify that grad mode is correctly set inside a generator and that exceptions and special conditions are correctly propagated. Now tests show more clearly the difference between the old implementation and the one from this PR.

The use-case you provide is tested by the existing test_set_grad_generator_functions, which uses next implicitly through the for loop expression. This test and the complementary new test test_set_grad_coroutines both check that the grads are set to the specified mode inside a wrapped generator. The latter test also makes sure that the two-way communication via .send works as expected.

Other tests verify that the grad mode is correctly set when the execution flow is inside the body of the generator under special conditions, such as exceptions thrown into the generator via .throw and early termination via .close.

Collaborator

@albanD albanD left a comment


Looks good. The inline comments look great!

Can you just fix the bad default arguments in the test? (that will fix the lint)

test/test_autograd.py (outdated)
test/test_autograd.py (outdated)
Contributor

@facebook-github-bot facebook-github-bot left a comment


@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@codecov

codecov bot commented Dec 16, 2020

Codecov Report

Merging #49017 (b077c4b) into master (5912316) will decrease coverage by 0.00%.
The diff coverage is 94.73%.

@@            Coverage Diff             @@
##           master   #49017      +/-   ##
==========================================
- Coverage   80.63%   80.63%   -0.01%     
==========================================
  Files        1875     1875              
  Lines      202714   202726      +12     
==========================================
+ Hits       163453   163458       +5     
- Misses      39261    39268       +7     

@facebook-github-bot
Contributor

@albanD merged this pull request in efc0906.

hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021
Summary: same as the PR description above.

Pull Request resolved: pytorch#49017

Reviewed By: izdeby

Differential Revision: D25567796

Pulled By: albanD

fbshipit-source-id: 801577cccfcb2b5e13a08e77faf407881343b7b0
Labels
cla signed; enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix); Merged; open source; triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
5 participants