Allow graph fuser to move chunks past multiple nodes. #14055
Looks awesome 👏
Would be good to expand our test suite a bit in this case. Ultimately we want all ways to parenthesize the expression `x.mm(w_ih) + h.mm(w_hh) + b_ih + b_hh` (possibly with permutation of the terms) to get fused correctly into some operations following the chunk, but we're only checking one at the moment.
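A hypothetical sketch of how a test generator could enumerate those cases. The `trees` helper and the term strings are illustrative, not from the PR; plain Python is enough to enumerate the expressions, which could then each be spliced into a scripted function and checked for a single FusionGroup:

```python
from itertools import permutations

# Terms of the LSTM gate expression under test (illustrative strings).
TERMS = ["x.mm(w_ih)", "h.mm(w_hh)", "b_ih", "b_hh"]

def trees(terms):
    """Yield every full parenthesization of `terms` joined with `+`."""
    if len(terms) == 1:
        yield terms[0]
        return
    for i in range(1, len(terms)):
        for left in trees(terms[:i]):
            for right in trees(terms[i:]):
                yield f"({left} + {right})"

# Every permutation of the terms, under every parenthesization.
exprs = {t for perm in permutations(TERMS) for t in trees(list(perm))}

# 4! permutations x Catalan(3) = 5 tree shapes -> 120 distinct expressions.
assert len(exprs) == 120
```

Each string in `exprs` is one way of associating and ordering the four addends, so a test could iterate over the set rather than hand-writing cases.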
I want to think about this PR a bit more, but I'm pretty confident that it's good and can land.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes pytorch#12290. Also speeds up the JIT LSTM forward pass from 8.8 ms to 7.8 ms.

Explanation: let `f`, `g`, `h` be fusible ops.

```
x = f(v, w)
z = g(x, y)
a, b = chunk(z)
c = h(a, b)
```

becomes (before this PR):

```
x = f(v, w)
x', y' = broadcast_tensors([x, y])
ax, bx = chunk(x')
ay, by = chunk(y')
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```

The graph fuser then puts g, g, and h into one FusionGroup and is unable to move `x = f(v, w)` into that FusionGroup.

This PR lets the graph fuser move `x = f(v, w)` into the FusionGroup. It does this by abstracting the broadcast_tensors + multiple chunk nodes into one intermediate `prim::BroadcastingChunk[chunks, dim]` node. A `BroadcastingChunk[chunks, dim](*inputs)` node is equivalent to:
- broadcasting all of `*inputs`, and
- chunking each broadcast input into `chunks` chunks along dim `dim`.

With the broadcasting-chunk behavior abstracted away, it is now a lot easier for the graph fuser to move (broadcast + chunk) past an operation. After this PR, the above graph becomes:

```
x = f(v, w)
ax, bx, ay, by = BroadcastingChunk(x, y)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```

Now, to move `x = f(v, w)` past the BroadcastingChunk, one just needs to add f's operands to the BroadcastingChunk:

```
ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
ax = f(av, aw)
bx = f(bv, bw)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```
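The output convention of the new node (all chunks of the first input, then all chunks of the next, and so on) can be modeled in a few lines of plain Python. This is a simplified sketch that skips the broadcasting step and treats tensors as equal-length flat lists; the function name mirrors the node but is otherwise illustrative:

```python
def broadcasting_chunk(inputs, chunks):
    """Toy model of prim::BroadcastingChunk[chunks, dim=0]: chunk every
    (already same-shaped) input into `chunks` pieces and return all
    pieces flattened, grouped per input: ax, bx, ay, by, ...
    Tensors are modeled as flat Python lists of equal length.
    """
    out = []
    for t in inputs:
        n = len(t) // chunks
        out.extend(t[i * n:(i + 1) * n] for i in range(chunks))
    return out

# x and y each chunked in two: output order is ax, bx, ay, by.
x = [1, 2, 3, 4]
y = [5, 6, 7, 8]
ax, bx, ay, by = broadcasting_chunk([x, y], chunks=2)
assert (ax, bx, ay, by) == ([1, 2], [3, 4], [5, 6], [7, 8])
```

With this ordering, appending a new operand to the node simply appends its `chunks` pieces to the end of the output list, which is what makes moving `f` past the node a local rewrite.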
- Minor fixes
- Deduplicate inputs to BroadcastingChunk
- Test deduplication of inputs
- Test permutations of `gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh` and assert all of them result in one FusionGroup (as opposed to multiple)
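The deduplication item above has a simple invariant: re-adding an operand that is already an input of the BroadcastingChunk must reuse the existing chunk outputs instead of growing the input list. A minimal sketch of that invariant, with an entirely hypothetical class and method name (the real pass operates on JIT IR nodes, not Python objects):

```python
class BroadcastingChunkSim:
    """Toy model of input deduplication when adding operands to a
    BroadcastingChunk node while moving an op past it."""

    def __init__(self, chunks):
        self.chunks = chunks
        self.inputs = []

    def add_input(self, value):
        """Return the input's index, appending it only if not present."""
        if value in self.inputs:
            return self.inputs.index(value)
        self.inputs.append(value)
        return len(self.inputs) - 1

bc = BroadcastingChunkSim(chunks=2)
assert bc.add_input("y") == 0
assert bc.add_input("v") == 1
assert bc.add_input("y") == 0   # deduplicated: index reused
assert len(bc.inputs) == 2      # input list did not grow
```

Without this check, moving several ops that share an operand past the node would chunk the same tensor multiple times.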
Windows test looks flaky (I see it on other PRs as well).
Summary: Fixes #12290. Also speeds up the JIT LSTM forward pass from 8.8 ms to 7.8 ms; previously, each JIT LSTM cell used 2 fused kernels, and now it uses only one (which matches the number of kernels cuDNN uses). See the PR description for the full explanation.

cc @apaszke @mruberry @zdevito

Pull Request resolved: pytorch/pytorch#14055

Differential Revision: D13159259

Pulled By: zou3519

fbshipit-source-id: 134e9e645c950384d9be6a06a883a10e17a73d7d