Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement new Loop and Scan operators #191

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jan 10, 2023

Related to #189

This PR implements a new low level Loop Op which can be easily transpiled to Numba (the Python perform method takes 9 lines, yay to not having to support C in the future).

It also implements a new higher level Scan Op which returns as outputs the last states + intermediate states of a looping operation. This Op cannot be directly evaluated, and must be rewritten as a Loop Op in Python/Numba backends. For the JAX backend it's probably fine to transpile directly from this representation into a lax.scan as the signatures are pretty much identical. That was not done in this PR.

The reason for the two types of outputs, is that they are useful in different contexts. Final states are sometimes all one needs, whereas intermediate states are generally needed for back propagation (not implemented yet). This allows us to choose which one (or both) of the outputs we want during compilation, without having to do complicated graph analysis.

The existing save_mem_new_scan is used to convert a general scan into a loop that only returns the last computed state. It's... pretty complicated (although it also covers cases where more than 1 but less than all steps being requested, but OTOH it can't handle while loops #178):

def save_mem_new_scan(fgraph, node):

Taking that as a reference I would say the new conversion rewrite from Scan to Loop is much much simpler. Most of it is boilerplate code for defining the right trace inputs and new FunctionGraph


Both Ops expect a FunctionGraph as input. This should probably be created by a user-facing helper that accepts a callable like scan does now. That was not done yet, as I first wanted to discuss the general design. Done

Design issues

1. The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan where the outputs_info must explicitly be None in these cases.

  1. Scan and Loop can now take random types as inputs (scan can't return it as a sequence). This makes random seeding much more explicit compared to the old Scan, which was based on default updates of shared variables. However it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a return_rng_update to __call__, so that it doesn't hide the next rng state output?

  2. Do we want to be able to represent empty Loop / Sequences? If so, how should we go about that? IfElse is one option, but perhaps it would be nice to represent it in the same Loop Op?

  3. What do we want to do in terms of inplacing optimizations?

TODO

If people are on board with the approach

  • Implement Numba dispatch
  • Implement JAX dispatch
  • Implement L_op and R_op
  • Implement friendly user facing functions
  • Decide on which meta-parameters to preserve (mode, truncate_gradient, reverse and so on)
  • Add rewrite that replaces trace[-1] by the first set of outputs (final state). That way we can keep the old API, while retaining the benefit of doing while Scans without tracing when it's not needed.

assert input_state.type == output_state.type


class Loop(Op):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Add mixin HasInnerGraph so that we can see the inner graph in debug_print

@ricardoV94 ricardoV94 force-pushed the looping branch 9 times, most recently from 76a9b4c to f2a2c03 Compare January 11, 2023 15:27
pytensor/loop/basic.py Outdated Show resolved Hide resolved
@aseyboldt
Copy link
Member

aseyboldt commented Jan 11, 2023

The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan where the outputs_info must explicitly be None in these cases.

Wouldn't a fill loop look something like this?

state = (pt.scalar(0), pt.empty(shape, dtype), rng)
def update(idx, values, rng):
    value, rng = rng.normal()  # not exactly the api...
    values = pt.set_subtensor(values[idx], value)
    return (idx + 1, values, rng, idx < maxval)

(and very much need inplace rewrites for good performance...)

Scan and Loop can now take random types as inputs (scan can't return it as a sequence). This makes random seeding much more explicit compared to the old Scan, which was based on default updates of shared variables. However it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a return_rng_update to call, so that it doesn't hide the next rng state output?

Good question...
Don't know either :-)

Do we want to be able to represent empty Loop / Sequences? If so, how should we go about that? IfElse is one option, but perhaps it would be nice to represent it in the same Loop Op?

I think one rewrite that get's easier with the if-else-do-while approach would be loop invariant code motion. Let's say we have a loop like

x = bigarray...
if not_empty:
    val = 0
    do:
        val = (val + x.sum()) ** 2
    while val < 10

# rewrite to
x = bigarray...
if not_empty:
    val = 0
    x_sum = x.sum()
    do:
        val = (val + x_sum) ** 2
    while val < 10

we could move x.sum() out of the loop. But with a while loop we can't as easily, because we only want to do x.sum() if the loop is not empty, and where would we then put that computation?

What do we want to do in terms of inplacing optimizations?

Well, I guess we really need those :-)
I'm thinking it might be worth it to copy the initial state, and then donate the state to the inner function? And I guess we need to make sure rewrites are actually running on inner graphs as well...

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jan 12, 2023

we could move x.sum() out of the loop. But with a while loop we can't as easily, because we only want to do x.sum() if the loop is not empty, and where would we then put that computation?

Why can't we move it even if it's empty? Sum works fine. Are you worried about Ops that we know will fail with empty inputs?

About the filling Ops, yeah I don't see it as a problem anymore. Just felt awkward to create the dummy input when translating from scan to loop. I am okay with it now

@aseyboldt
Copy link
Member

That would change the behavior. If we move it out and don't prevent it from being executed, things could fail for instance if there's an assert somewhere, or some other error happens during it's evaluation. Also, it could be potentially very costly (let's say "solve an ode").

(somehow I accidentally edited your comment instead of writing a new one, no clue how, but fixed now)

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jan 12, 2023

In my last commit, sequences are demoted from special citizens to just another constant input in the ScanOp. The user facing helper creates the right graph with indexing that is passed to the user provided function.

I have reverted converting the constant inputs to dummies before calling the user function, which allows the example in the jacobian documentation to work, including the one that didn't work before (because both are now equivalent under the hood :))

https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html#computing-the-jacobian

I reverted too much, and I still need to pass dummy inputs as the state variables, since it doesn't make sense for the user function to introspect the graph beyond the initial state (since it's only valid for the initial state)

@ricardoV94 ricardoV94 force-pushed the looping branch 2 times, most recently from 7bcd42c to 6c953b3 Compare January 13, 2023 11:17
return last_states[1:], traces[1:]


def map(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about subclassing Scan into

  • Map(Scan)
  • Reduce(Scan)
  • Filter(Scan)

It will be easier to dispatch into optimized implementations

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do that later, not convinced we need that yet

if init_state is None:
# next_state may reference idx. We replace that by the initial value,
# so that the shape of the dummy init state does not depend on it.
[next_state] = clone_replace(
Copy link
Member

@ferrine ferrine Jan 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not graph_replace or using memo for FunctionGraph(memo={symbolic_idx: idx}) (here)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that better?

@ricardoV94
Copy link
Member Author

Added a simple JAX dispatcher, works in the few examples I tried

# explicitly triggers the optimization of the inner graphs of Scan?
update_fg = op.update_fg.clone()
rewriter = get_mode("JAX").optimizer
rewriter(update_fg)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gives an annoying Supervisor Feature missing warning... gotta clean that up


print(max_iters)
states, traces = jax.lax.scan(
scan_fn, init=list(states), xs=None, length=max_iters
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Todo: Check we are not missing performance by not having explicit sequences.

Todo: When there are multiple sequences PyTensor defines n_steps as the shortest sequence. JAX should be able to handle this, but if not we could consider not allowing sequences/n_steps with different lengths in the Pytensor scan.

Then we could pass a single shape as n_steps after asserting they are the same?

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jan 16, 2023

I just found out about TypedLists in PyTensor. That should allow us to trace any type of Variables, including RandomTypes 🤯

Pushed a couple of commits that rely on this.

@ricardoV94 ricardoV94 force-pushed the looping branch 5 times, most recently from 5f15c5e to 32b4fb4 Compare January 20, 2023 14:53
ricardoV94 and others added 6 commits January 20, 2023 16:41
Co-authored-by: Adrian Seyboldt <adrian.seyboldt@gmail.com>
Co-authored-by: Adrian Seyboldt <adrian.seyboldt@gmail.com>
This was not possible prior to use of TypedListType for non TensorVariable sequences, as it would otherwise not be possible to represent indexing of last sequence state, which is needed e.g., for shared random generator updates.
@codecov-commenter
Copy link

Codecov Report

Merging #191 (5bc7070) into main (958cd14) will increase coverage by 0.06%.
The diff coverage is 89.11%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #191      +/-   ##
==========================================
+ Coverage   80.03%   80.09%   +0.06%     
==========================================
  Files         170      173       +3     
  Lines       45086    45435     +349     
  Branches     9603     9694      +91     
==========================================
+ Hits        36085    36392     +307     
- Misses       6789     6818      +29     
- Partials     2212     2225      +13     
Impacted Files Coverage Δ
pytensor/compile/mode.py 84.47% <ø> (ø)
pytensor/loop/basic.py 81.44% <81.44%> (ø)
pytensor/loop/op.py 90.29% <90.29%> (ø)
pytensor/link/jax/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/jax/dispatch/loop.py 100.00% <100.00%> (ø)
pytensor/link/utils.py 60.30% <100.00%> (+0.12%) ⬆️
pytensor/typed_list/basic.py 89.27% <100.00%> (+0.38%) ⬆️
pytensor/link/jax/dispatch/extra_ops.py 74.62% <0.00%> (-20.90%) ⬇️
pytensor/link/jax/dispatch/shape.py 80.76% <0.00%> (-7.70%) ⬇️
pytensor/link/jax/dispatch/basic.py 79.03% <0.00%> (-4.84%) ⬇️
... and 11 more

@ricardoV94
Copy link
Member Author

ricardoV94 commented Oct 23, 2023

This Discourse thread is a great reminder of several Scan design issues that are fixed here: https://discourse.pymc.io/t/hitting-a-weird-error-to-do-with-rngs-in-scan-in-a-custom-function-inside-a-potential/13151/15

Namely:

  • Going to the root to find missing non-sequences (instead of using truncated_graph_inputs
  • Gradient only works by indexing non-sequences
  • Scans are very difficult to manipulate!!!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants