Make input output_grads to OFG pullback function consistent at runtime #2065
|
I think there was some logic for the "hack". If you have a graph with an OFG that has two independent "streams", say [x, y] -> [x+1, y+1] (no interaction between them), and you get a disconnected gradient wrt the second output, the correct behavior is to return a disconnected gradient wrt the first input. But now, even if the gradient wrt this input was a dead end, the graph won't compile. This partitioning away is something Scan also does, there I think by feeding dummy zeros and then ignoring them. Is the path you changed just for the custom pullback? I assume the default pullback is still doing the same and working? If so, fine, but it may be worth a note in the section on custom pullback that it won't work well for fully independent, disconnected streams in the inner graph. |
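To make the two strategies mentioned above concrete, here is a minimal toy sketch in plain Python. It does not use PyTensor, and it is not how OFG or Scan are implemented; the `DISCONNECTED` sentinel and both function names are hypothetical, introduced only for illustration. It contrasts passing a disconnected marker through each independent stream with feeding dummy zeros to a fixed-signature inner computation and masking the results afterwards.

```python
# Toy model only: plain Python, not PyTensor's API.
# DISCONNECTED is a hypothetical sentinel standing in for a disconnected gradient.
DISCONNECTED = object()


def pullback_propagate(inputs, output_grads):
    """Strategy 1: pass the disconnected marker through each independent stream."""
    x, y = inputs
    g0, g1 = output_grads
    # The streams are x -> x + 1 and y -> y + 1, so each local derivative is 1.
    dx = DISCONNECTED if g0 is DISCONNECTED else g0 * 1.0
    dy = DISCONNECTED if g1 is DISCONNECTED else g1 * 1.0
    return [dx, dy]


def pullback_dummy_zeros(inputs, output_grads):
    """Strategy 2: feed zeros for disconnected output grads, then mask the results."""
    connected = [g is not DISCONNECTED for g in output_grads]
    dense = [g if ok else 0.0 for g, ok in zip(output_grads, connected)]
    # The inner computation always receives numeric values, so its inputs
    # have the same form no matter which outputs are connected.
    dx, dy = dense[0] * 1.0, dense[1] * 1.0
    # Streams are independent here, so input i is connected iff output i is.
    return [d if ok else DISCONNECTED for d, ok in zip([dx, dy], connected)]


grads_a = pullback_propagate([2.0, 3.0], [1.0, DISCONNECTED])
grads_b = pullback_dummy_zeros([2.0, 3.0], [1.0, DISCONNECTED])
```

In this toy setting both strategies return the same result. The difference is only in what the inner computation sees: the second one never has to branch on a disconnected input.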
|
I think I don't fully understand your comment. The issue is just about the consistency of the inputs to the inner |
|
I agree with you. The only thing I'm concerned about is that this is still working:

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import DisconnectedType, disconnected_type


def pullback(inputs, outputs, output_grads):
    x, y = inputs
    dout0, dout1 = output_grads
    if isinstance(dout0.type, DisconnectedType):
        dx = disconnected_type()
    else:
        dx = dout0 * 2 * x
    if isinstance(dout1.type, DisconnectedType):
        dy = disconnected_type()
    else:
        dy = dout1 * 3 * y**2
    return [dx, dy]


x = pt.scalar("x")
y = pt.scalar("y")

# ofg has two streams that are completely independent
# test we can get the gradient from a single output wrt the connected input
op = pytensor.OpFromGraph([x, y], [x**2, y**3])
op = pytensor.OpFromGraph([x, y], [x**2, y**3], pullback=pullback)
out0, out1 = op(x, y)
grad_out0_wrt_x = pt.grad(out0, x)
fn = pytensor.function([x, y], grad_out0_wrt_x)
assert fn(2.0, 2.0) == 4.0

I don't think we are testing this case. Can you confirm and add it? There's another test with a disconnected type, but with related outputs. |
Closes #2064