rf.Cond and rf.Loop for eager frameworks #1282

Closed
albertz opened this issue Mar 23, 2023 · 4 comments
albertz commented Mar 23, 2023

I want to collect some thoughts on possible options to implement rf.Cond and rf.Loop (like nn.Cond and nn.Loop in RC, i.e. returnn_common) for the RETURNN frontend (import returnn.frontend as rf, #1120). This is not straightforward for eager-mode frameworks, where the code should execute directly.

  • Rewrite the Python AST, based on the current RC API. But when? Is it already too late once rf.Cond is called? Would this require a function decorator?
    Pro:

    • Code looks simple.
    • Can keep existing RC code.

    Con:

    • Might be unintuitive. There would be some magic involved.
  • Redesign the API of Cond and Loop to support both cases (eager-mode and graph-mode) well?
    Pro:

    • Still Pythonic, quite straightforward to implement.

    Con:

    • Breaks some existing RC code.
    • Maybe too annoying to use? Code looks ugly?
    • If code is tested and works in an eager-mode framework, there is no good guarantee that it also works in a graph-based framework? Unclear.

    Design:

    with rf.Cond(condition) as cond:
      if cond.true_branch():
        ...
        cond.result = ...
      if cond.false_branch():
        ...
        cond.result = ...
    y = cond.result
    
    with rf.Loop() as loop:
      while loop.iteration():
        ...
  • In the backend, have an extra check for the control flow context in all functions, and skip execution when not in the right scope.
    Pro:

    • Probably can keep existing RC code.

    Con:

    • Makes the code of every backend function, in every eager-mode backend, more complex.
    • Not sure if it really works in all cases.
    • When native framework code (e.g. pure PT code) is mixed with RF code, this will not work correctly for the native code. E.g. pure PT code would be executed in any case.
  • Use an interface similar to tf.cond and tf.while_loop, using function callbacks.
    Pro:

    • Should be straightforward to implement.

    Con:

    • Too verbose, too annoying to use.
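To make the second option (the redesigned context-manager API) more concrete, here is a minimal sketch of how it could behave in an eager-mode backend. This `Cond` class is a hypothetical stand-in, not the RETURNN implementation; it assumes the condition is already a plain Python bool. In eager mode, `true_branch()` and `false_branch()` simply report whether the branch should run. A graph-mode backend would presumably return True from both and record each branch separately.

```python
class Cond:
    """Toy eager-mode stand-in for the proposed rf.Cond design (hypothetical)."""

    def __init__(self, condition: bool):
        self.condition = bool(condition)
        self.result = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False  # do not swallow exceptions

    def true_branch(self) -> bool:
        # Eager mode: run the true branch only if the condition holds.
        return self.condition

    def false_branch(self) -> bool:
        # Eager mode: run the false branch only if the condition does not hold.
        return not self.condition


def abs_value(x: float) -> float:
    """User code written against the proposed design."""
    with Cond(x >= 0) as cond:
        if cond.true_branch():
            cond.result = x
        if cond.false_branch():
            cond.result = -x
    return cond.result
```

Only one of the two `if` bodies executes here, so user code stays close to ordinary Python while still marking the branch boundaries explicitly.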

patrick-wilken (Contributor) commented

So the current interface in returnn_common is this, right?

with Cond(cond) as cond_obj:
    cond_obj.true = mod_true_case(x)
    cond_obj.false = mod_false_case(x)
    y = cond_obj.result

(And the whole point is to execute only one of the graphs defined in mod_true_case / mod_false_case.)

I haven't used it myself, so this is already not totally clear to me. From the code I see that entering the context manager and assigning to true and false will trigger switching the "name_ctx", which will collect separate subnetworks created between enter and true / between true and false, right?
Isn't this already pretty specific to network dict construction? For other graph based frameworks, e.g. if x is a tf.Tensor, wouldn't running mod_true_case(x) already create ops that we want to be inside the tf.cond?

So what about doing it similar to tf.cond itself?:

x = ...
def true_fn():
    return mod_true_case(x)

def false_fn():
    return mod_false_case(x)

y = rf.Cond(cond, true_fn, false_fn)

Was this considered to be too verbose for returnn_common? Or am I missing something about this name_ctx? At least this would work for eager frameworks too: you can just call one of the functions right away.
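The point that an eager framework can "just call one of the functions right away" can be sketched as follows. This is an illustration only: the lowercase `cond` function is a hypothetical eager-mode helper, it assumes the predicate is a plain Python bool, and `mod_true_case` / `mod_false_case` are dummy stand-ins that record when they are called.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def cond(pred: bool, true_fn: Callable[[], T], false_fn: Callable[[], T]) -> T:
    """Eager-mode sketch: evaluate only the selected branch."""
    return true_fn() if pred else false_fn()


calls: List[str] = []  # records which branch functions actually ran


def mod_true_case(x: int) -> int:
    calls.append("true")
    return x + 1


def mod_false_case(x: int) -> int:
    calls.append("false")
    return x - 1


x = 10
y = cond(x > 5, lambda: mod_true_case(x), lambda: mod_false_case(x))
```

After this runs, only `mod_true_case` has been called. A graph-based backend could implement the same signature by tracing both callbacks into separate subgraphs.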

But this is an important issue. If people don't care for the graph based case and writing

y = mod_true_case(x) if cond else mod_false_case(x)

in eager mode works, then they will probably do it. Or, if we make it impossible, they will consider not using the frontend at all, because pure PyTorch is simpler. So when writing the frontend to be framework-independent, there is the risk that we end up with the union of all limitations of the different frameworks.


albertz commented Mar 27, 2023

> From the code I see that entering the context manager and assigning to true and false will trigger switching the "name_ctx", which will collect separate subnetworks created between enter and true / between true and false, right?

Right.

> Isn't this already pretty specific to network dict construction? For other graph based frameworks, e.g. if x is a tf.Tensor, wouldn't running mod_true_case(x) already create ops that we want to be inside the tf.cond?

No, this is not specific to the net dict construction. This can work for any graph based frameworks. We simply need to switch to the right control flow context (TF control flow v1) or sub graph (TF control flow v2) (which corresponds to being inside the true or false branch inside tf.cond).

> So what about doing it similar to tf.cond itself?

Yes, that is another possibility. I will add this to the list above. I personally found this too verbose and ugly, which is why I did not want to use it for RC. I thought that using such a context manager would look simpler.

But now considering this issue here, it's maybe actually a good solution, as it would be very straightforward, both for eager-based and graph-based frameworks.

> But this is an important issue. If people don't care for the graph based case and writing
>
> y = mod_true_case(x) if cond else mod_false_case(x)

If cond is a Python bool, so really a constant, then it does not matter.

If cond is a scalar Tensor, then we can just implement __bool__ and make it throw an error. This is similar to tf.Tensor, I think. Then the user can never write it like that. Of course, the user could write:

y = mod_true_case(x) if cond.raw_tensor.numpy() else mod_false_case(x)

But once the user accesses raw_tensor, it is anyway clear that it is backend specific.
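A minimal sketch of the `__bool__` idea, assuming a hypothetical wrapper class named `Tensor` with a `raw_tensor` attribute (the class body and the error message are illustrative, not the RETURNN implementation):

```python
class Tensor:
    """Hypothetical minimal tensor wrapper; only the __bool__ behavior matters here."""

    def __init__(self, raw_tensor):
        self.raw_tensor = raw_tensor

    def __bool__(self):
        # Refuse implicit conversion so that `if tensor:` or `a if tensor else b`
        # cannot silently bypass the framework-independent control flow API.
        raise TypeError(
            "Tensor cannot be used as a Python bool; use rf.cond instead"
        )


cond = Tensor(True)
try:
    y = "true branch" if cond else "false branch"
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Explicitly going through raw_tensor still works, and is visibly backend-specific.
y_raw = "true branch" if cond.raw_tensor else "false branch"
```

The conditional expression raises before either branch is evaluated, so the mistake surfaces immediately in eager mode instead of only when the code is later run on a graph-based backend.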


albertz commented Mar 29, 2023

So, any opinions?

I currently lean towards the TF-like APIs (so rf.cond, rf.while_loop), as they would be very straightforward and simple to implement, without requiring any magic. Of course, they have the downside that the user code looks more complicated.
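For illustration, a callback-based loop in the style of tf.while_loop needs very little code in eager mode. The following is a sketch under assumptions: the function name `while_loop` and its parameter names (`cond`, `body`, `initial`) are chosen here for the example and are not necessarily the signature of the actual rf.while_loop.

```python
from typing import Callable, Tuple, TypeVar

S = TypeVar("S")


def while_loop(cond: Callable[[S], bool], body: Callable[[S], S], initial: S) -> S:
    """Eager-mode sketch: repeatedly apply body while cond holds."""
    state = initial
    while cond(state):
        state = body(state)
    return state


def keep_going(state: Tuple[int, int]) -> bool:
    i, _total = state
    return i < 5


def step(state: Tuple[int, int]) -> Tuple[int, int]:
    i, total = state
    return i + 1, total + i


# Sum the integers 0..4, carrying (counter, running total) as the loop state.
final_i, final_total = while_loop(keep_going, step, (0, 0))
```

The extra verbosity compared to a plain Python `while` is visible: the loop state must be passed explicitly through `cond` and `body`. In exchange, a graph-based backend can trace both callbacks once and build a native loop from them.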


albertz commented May 10, 2023

So we implemented rf.cond and rf.while_loop (and also rf.scan, #1324) now.

@albertz albertz closed this as completed May 10, 2023