Replace several Ops by OpFromGraph#2110
Conversation
c538541 to
f02d376
Compare
8fd0886 to
6a5a28f
Compare
jessegrabowski
left a comment
There was a problem hiding this comment.
very cool, seems to simplify things a lot.
If I understand well, this is new SymbolicTensorOp should be used in cases when I just want to "mark" a block of computation for rewrites to reason about jointly? What is the use-case for standard OpFromGraph after this?
| np.testing.assert_allclose(gyv, 0.0) | ||
|
|
||
|
|
||
| def test_xlogy_as_xlogx(): |
There was a problem hiding this comment.
probably don't need this one
There was a problem hiding this comment.
It's good to confirm we're not being bitten by that repeated input in the inner graph gotcha.
There was a problem hiding this comment.
but the tests can certainly be merged
This is just syntactic sugar of top of OFG with nicer API imo. If you want it to work with different types (vector, matrix, etc) you give a generative function for it. The big difference is this function lives in the Op class, whereas before it used to be a distinct floating helper. This is pretty much the thing that SymbolicRV evolved into, so it's the product of some learning of what's needed in practice. As I said this could be merged into OFG (or one day replace OFG), but it already has so much complex API I didn't want to burden it even more. |
That and custom gradients. It's the reason why xlogy and xlogyp1 are not just helpers that build the forward graph |
Adds a `print_inner_graphs` parameter to `debugprint()` with three modes: - "auto" (default): show inner graphs except for SymbolicOp - True: show all inner graphs - False: hide all inner graphs This reduces noise when printing graphs that contain SymbolicOps, whose inner graphs are implementation details rather than user logic.
Related to #1210
These Ops are needed to hide the inner graph for easier graph rewriting, or to control autodiff. There's no need to implement new primitives for this functionality. Removing primitive Ops makes maintaining multiple backends easier, and benefits from optimizations we do do the core primitives.
I had postponed softmax because it had a custom C kernel, but with the change from it being the default I think it's fine. Numba is doing the same thing it was doing before (but less code), and libraries that may prefer direct dispatch like JAX/PyTorch/MLX continue to do so.
I adde a new SymbolicOp subclass, much in line with the prototype SymbolicRandomVariable in PyMC. It has a method to build the inner graph on demant from input types, and a simpler interface than OFG (OFG is a bit too much already, so I didn't want to extend it further, we can reconsider this).
Next we could tackle #1221, which would allow us to still work with them as objects (needed for dispatching in PyMC), but allows us to have less real Ops, and to also inline, which can be great for memory optimization / rewrites. (E.g., a MvNormal L -> covariance -> L across the RV boundary doesn't get optimized now), plus memory opt and all that. That's left for another PR though