Skip to content

Conversation

zou3519
Copy link
Contributor

@zou3519 zou3519 commented Nov 16, 2018

Fixes #12290. Also speeds up JIT LSTM forward pass from 8.8ms to 7.8ms; previously, each JIT lstm cell used 2 fused kernels. Now, it only uses one fused kernel (which is how many kernels cudnn uses).

Explanation:

Let f, g, h be fusible ops.

x = f(v, w)
z = g(x, y)
a, b = chunk(z)
c = h(a, b)

becomes (before this PR):

x = f(v, w)
x', y' = broadcast_tensors([x, y])
ax, bx = chunk(x')
ay, by = chunk(y')
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)

The graph fuser then puts g, g, and h into one FusionGroup and is unable
to move x = f(v, w) into the FusionGroup.

This PR lets the graph fuser move x = f(v, w) into the FusionGroup.
It does this by abstracting the broadcast_tensors + multiple chunk nodes
into one intermediate prim::BroadcastingChunk[chunks, dim] node.

A BroadcastingChunk[chunks, dim](*inputs) node is equivalent to:

  • broadcasting all of *inputs
  • chunk-ing each broadcasted input into chunks chunks along dim dim.

Abstracting the broadcasting chunk behavior away, it is now a lot easier
for the graph fuser to move (broadcast + chunk) past an operation. After
this PR, the above graph becomes:

x = f(v, w)
ax, bx, ay, by = BroadcastingChunk(x, y)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)

Now, to move x = f(v, w) after the BroadcastingChunk, one just needs
to add f's operands to the BroadcastingChunk:

ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
ax = f(av, aw)
by = f(bv, bw)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)

cc @apaszke @mruberry @zdevito

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Nov 16, 2018
Copy link
Contributor

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks awesome 👏

Would be good to expand our test suite a bit in this case. Ultimately we want all ways to parenthesize this expression (possibly with permutation) x.mm(w_ih) + h.mm(w_hh) + b_ih + b_hh to get fused correctly into some operations following the chunk, but we're only checking one at the moment.

I want to think about this PR a bit more, but I'm pretty confident that it's good and can land.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Fixes pytorch#12290. Also speeds up JIT LSTM forward pass from 8.8ms to 7.8ms.

Explanation:

Let f, g, h be fusible ops.
```
x = f(v, w)
z = g(x, y)
a, b = chunk(z)
c = h(a, b)
```
becomes (before this PR):
```
x = f(v, w)
x', y' = broadcast_tensors([x, y])
ax, bx = chunk(x')
ay, by = chunk(y')
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```
The graph fuser then puts g, g, and h into one FusionGroup and is unable
to move `x = f(v, w)` into the FusionGroup.

This PR lets the graph fuser move `x = f(v, w)` into the FusionGroup.
It does this by abstracting the broadcast_tensors + multiple chunk nodes
into one intermediate prim::BroadcastingChunk[chunks, dim] node.

A BroadcastingChunk[chunks, dim](*inputs) node is equivalent to:
- broadcasting all of *inputs
- chunk-ing each broadcasted input into `chunks` chunks along dim `dim`.

Abstracting the broadcasting chunk behavior away, it is now a lot easier
for the graph fuser to move (broadcast + chunk) past an operation. After
this PR, the above graph becomes:
```
x = f(v, w)
ax, bx, ay, by = BroadcastingChunk(x, y)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```
Now, to move `x = f(v, w)` after the BroadcastingChunk, one just needs
to add f's operands to the BroadcastingChunk:
```
ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
ax = f(av, aw)
by = f(bv, bw)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```
- Minor fixes
- Deduplicate inputs to BroadcastingChunk
- test deduplication of inputs
- test permutations of gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih +
  b_hh and assert all of them result in one FusionGroup (as opposed to
  multiple)
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519
Copy link
Contributor Author

zou3519 commented Nov 26, 2018

Windows test looks flaky (I see it on other PRs as well).

zdevito pushed a commit to zdevito/ATen that referenced this pull request Nov 26, 2018
Summary:
Fixes #12290. Also speeds up JIT LSTM forward pass from 8.8ms to 7.8ms; previously, each JIT lstm cell used 2 fused kernels. Now, it only uses one fused kernel (which is how many kernels cudnn uses).

Explanation:

Let f, g, h be fusible ops.
```
x = f(v, w)
z = g(x, y)
a, b = chunk(z)
c = h(a, b)
```
becomes (before this PR):
```
x = f(v, w)
x', y' = broadcast_tensors([x, y])
ax, bx = chunk(x')
ay, by = chunk(y')
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```
The graph fuser then puts g, g, and h into one FusionGroup and is unable
to move `x = f(v, w)` into the FusionGroup.

This PR lets the graph fuser move `x = f(v, w)` into the FusionGroup.
It does this by abstracting the broadcast_tensors + multiple chunk nodes
into one intermediate `prim::BroadcastingChunk[chunks, dim]` node.

A `BroadcastingChunk[chunks, dim](*inputs)` node is equivalent to:
- broadcasting all of *inputs
- chunk-ing each broadcasted input into `chunks` chunks along dim `dim`.

Abstracting the broadcasting chunk behavior away, it is now a lot easier
for the graph fuser to move (broadcast + chunk) past an operation. After
this PR, the above graph becomes:
```
x = f(v, w)
ax, bx, ay, by = BroadcastingChunk(x, y)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```
Now, to move `x = f(v, w)` after the BroadcastingChunk, one just needs
to add f's operands to the BroadcastingChunk:
```
ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
ax = f(av, aw)
by = f(bv, bw)
a = g(ax, ay)
b = g(bx, by)
c = h(a, b)
```

cc apaszke mruberry zdevito
Pull Request resolved: pytorch/pytorch#14055

Differential Revision: D13159259

Pulled By: zou3519

fbshipit-source-id: 134e9e645c950384d9be6a06a883a10e17a73d7d
@ezyang ezyang added the merged label Jun 25, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[jit] graph fuser doesn't look past broadcasts when fusing

4 participants