Make input output_grads to OFG pullback function consistent at runtime #2065

Merged
jessegrabowski merged 2 commits into pymc-devs:v3 from jessegrabowski:ofg-pullback on Apr 21, 2026

Conversation

@jessegrabowski
Member

Closes #2064

@jessegrabowski added the bug (Something isn't working), gradients, and OpFromGraph labels on Apr 20, 2026
@jessegrabowski changed the title from "OFG pullback output_grads is consistent" to "Make input output_grads to OFG pullback function consistent at runtime" on Apr 20, 2026
@ricardoV94
Member

ricardoV94 commented Apr 20, 2026

I think there was some logic behind the "hack". Suppose a graph contains an OFG with two independent "streams", say [x, y] -> [x+1, y+1], with no interaction between them. If you get a disconnected gradient wrt the second output, the correct behavior is to return a disconnected gradient wrt the first input. But now, even if the gradient wrt this input is a dead end, the graph won't compile.

Scan also does this kind of partitioning, I think by feeding in dummy zeros and then ignoring them. Is the path you changed only for the custom pullback? I assume the default pullback still does the same thing and still works?

If so, fine, but it may be worth adding a note to the section on custom pullbacks saying that they won't work well for fully independent, disconnected streams in the inner graph.
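
The two-stream case can be pictured without PyTensor at all. The sketch below is a toy, plain-Python model: `DISCONNECTED` is a hypothetical sentinel standing in for a disconnected gradient, and `two_stream_pullback` is a made-up hand-written pullback for [x, y] -> [x + 1, y + 1]. It only illustrates the expected per-stream behavior and is not how OpFromGraph or Scan implement it.

```python
# Toy model only: DISCONNECTED and two_stream_pullback are hypothetical names,
# not part of PyTensor.
DISCONNECTED = object()  # stands in for a disconnected gradient


def two_stream_pullback(inputs, output_grads):
    """Pullback for [x, y] -> [x + 1, y + 1], where the streams never interact."""
    dout0, dout1 = output_grads  # assumed: always one entry per output
    # d(x + 1)/dx = 1 and d(y + 1)/dy = 1, so each input gradient is just the
    # matching output gradient, or stays disconnected if that output is unused.
    dx = DISCONNECTED if dout0 is DISCONNECTED else dout0 * 1.0
    dy = DISCONNECTED if dout1 is DISCONNECTED else dout1 * 1.0
    return [dx, dy]


# Only the first output feeds the cost, so only x receives a gradient.
dx, dy = two_stream_pullback([2.0, 3.0], [1.0, DISCONNECTED])
```

In this toy model each stream is handled on its own, so a disconnected output gradient never blocks the gradient of the other stream.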

@jessegrabowski
Member Author

I don't think I fully understand your comment. The issue is only about the consistency of the inputs to the inner pullback function. Because of this hack, it is impossible to predict what a user-defined pullback function should do, because you cannot know what the values in output_grads are going to be. I'm pretty confident we support a DisconnectedType coming into a pullback via output_grads?
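
To make the consistency requirement concrete, here is a minimal plain-Python sketch of the kind of contract a user-defined pullback would want to rely on: output_grads always has exactly one entry per output, and an unused output appears as a disconnected marker in its own slot. `DISCONNECTED` and `connected_positions` are hypothetical names for illustration; this is an assumption about the intended contract, not a description of how PyTensor enforces it.

```python
# Toy model only: DISCONNECTED and connected_positions are hypothetical names,
# not part of PyTensor.
DISCONNECTED = object()  # stands in for a disconnected output gradient


def connected_positions(n_outputs, output_grads):
    """Check the assumed contract and return the indices of connected outputs."""
    if len(output_grads) != n_outputs:
        # A pullback cannot unpack output_grads reliably if entries go missing.
        raise ValueError(
            f"expected {n_outputs} output gradients, got {len(output_grads)}"
        )
    return [i for i, g in enumerate(output_grads) if g is not DISCONNECTED]


# Two outputs, only the second one is used downstream.
connected = connected_positions(2, [DISCONNECTED, 0.5])
```

With a fixed-length output_grads, a pullback can unpack it positionally and branch on each entry, which is what the example pullback later in this thread does.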

@ricardoV94
Member

ricardoV94 commented Apr 20, 2026

I agree with you, the only thing I'm concerned about is that this is still working:

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import DisconnectedType, disconnected_type

def pullback(inputs, outputs, output_grads):
    x, y = inputs
    dout0, dout1 = output_grads
    if isinstance(dout0.type, DisconnectedType):
        dx = disconnected_type()
    else:
        dx = dout0 * 2 * x
    
    if isinstance(dout1.type, DisconnectedType):
        dy = disconnected_type()
    else:
        dy = dout1 * 3 * y**2
    return [dx, dy]

x = pt.scalar("x")
y = pt.scalar("y")

# The OFG has two streams that are completely independent.
# Test that we can get the gradient of a single output wrt its connected input.
op = pytensor.OpFromGraph([x, y], [x**2, y**3], pullback=pullback)

out0, out1 = op(x, y)
grad_out0_wrt_x = pt.grad(out0, x)
fn = pytensor.function([x, y], grad_out0_wrt_x)
assert fn(2.0, 2.0) == 4.0

I don't think we are testing this case. Can you confirm and add it? There's another test with a disconnected type, but with related outputs.

@jessegrabowski jessegrabowski merged commit e8366d1 into pymc-devs:v3 Apr 21, 2026
66 checks passed

Labels

bug (Something isn't working), gradients, OpFromGraph

Development

Successfully merging this pull request may close these issues.

OpFromGraph with multiple outputs provides inconsistent inputs to the inner pullback function

2 participants