[nnc] Support thread level parallelism in fused kernels #63386
Conversation
CI failures summary (Dr. CI, as of commit 82d2274): 4 new failures recognized by patterns; these failures do not appear to be due to upstream breakages.
@bertmaher has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/csrc/jit/tensorexpr/kernel.cpp (Outdated)
    for (int64_t i = loops.size(); i > 0; i--) {
      auto const& loop = loops[i - 1];
      if (auto stop = to<IntImm>(loop->stop())) {
        grainSize *= stop->value();
This assumes the loops are normalized at this point. While that is probably true in most cases, I'm not sure it is guaranteed.
I don't think anything's likely to un-normalize the loops before this point, but maybe I should simplify (and do stop-start) just to be safe. I guess the worst that happens is we miss a parallelization opportunity.
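As a reference, here is a minimal sketch of the "simplify and do stop - start" variant discussed above. This is not the PR's diff; it assumes the tensorexpr IR helpers (`ExprPtr`, `alloc<Sub>`, `IRSimplifier::simplify`, `to<IntImm>`) behave as in the codebase at the time:

```cpp
// Sketch only (assumes the torch::jit::tensorexpr headers, e.g.
// <torch/csrc/jit/tensorexpr/ir_simplifier.h>, and its namespace).
//
// Accumulate the extents of the outer loops without assuming they are
// normalized, i.e. use (stop - start) instead of stop.
int64_t grainSize = 1;
for (int64_t i = loops.size(); i > 0; i--) {
  auto const& loop = loops[i - 1];
  ExprPtr extent =
      IRSimplifier::simplify(alloc<Sub>(loop->stop(), loop->start()));
  if (auto e = to<IntImm>(extent)) {
    grainSize *= e->value();
  }
  // What to do when the extent is not a compile-time constant is left
  // open here, matching the excerpt above.
}
```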
    template <typename Bufs>
    static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
      for (auto const& buf : bufs) {
        auto loops = l.getLoopStmtsFor(buf);
Since this function is called after fuseAllLoops, it is possible that multiple buffers belong to the same loopnest, so we could be repeating this loop multiple times for the same loopnest. I understand that this may not be incorrect at this point, but it could lead to bugs in the future.
IMO, we shouldn't be looking at output buffers and their loopnests. Instead, we should just take the root_stmt in the given LoopNest and apply parallelization to all loopnests in that stmt. Wdyt?
That seems OK to me, sure.
I guess where things could get a little weird is if multiple buffers are updated at different levels of the loopnest, e.g.:

    for i:
      y1[] = ...
      for j:
        y2[] = ...

The current approach sort of gives each buffer an "independent" chance to affect the loop parallelization. That doesn't seem terrible, tbh.
Actually, the more I think about it, is there really any advantage to starting with the root stmt and working down? From my POV it just makes the code a lot more complicated; the way things work now, I just get a nice vector of loops leading to a buffer and try to flatten them. If it's not flattenable, it simply fails and I give up.
Although, maybe it's not too much work to build up my own vector starting from the root. Idk.
If you are okay with having the same set of loops handled here for different bufs, then I have no objections to it.
Personally, I felt starting from the root_stmt might be better. We might need another API to extract all loops in the root_stmt, so maybe we can do this in the future.
If I understand correctly, going through all buffers would not miss parallelism opportunities like the following:

    for i
      for j1
        y1 = ... [data dependence exists between iterations]
      for j2
        y2 = ... [no data dependence between iterations]

i+j1 cannot be parallelized because there's a data dependence between iterations for y1, but i+j2 can be parallelized and we should not miss it. If this is what we are trying to do here, I guess we need a distribute transformation before flatten; flatten currently only handles perfectly nested loops.
Yeah the approach here will definitely miss opportunities where some nested loops are parallelizable. It's really kind of a best-effort thing to get simple elementwise fusions right, not a general solution to parallelism.
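For illustration, here is a hedged sketch of the distribute-before-flatten idea from the comment above. It is not part of this PR, and it assumes `LoopNest::distributeLoop`, `LoopNest::flatten`, `NodeFinder` (from tensorexpr/analysis.h), and `For::set_parallel` behave as in the tensorexpr API of the time; `loops` is the vector returned by `l.getLoopStmtsFor(buf)` in the snippet above.

```cpp
// Sketch only: turn an imperfect nest such as
//   for i { for j1 { y1 = ... }  for j2 { y2 = ... } }
// into two perfect nests by distributing the outer loop, then try to
// flatten each resulting nest on its own.
std::vector<ForPtr> distributed = LoopNest::distributeLoop(loops.front());
for (auto const& outer : distributed) {
  // Collect the loops of this (now perfect) nest, outermost first.
  std::vector<ForPtr> nest = {outer};
  for (auto const& inner : NodeFinder<For>::find(outer->body())) {
    nest.push_back(inner);
  }
  ForPtr flattened = nullptr;
  if (LoopNest::flatten(nest, &flattened) && flattened) {
    // Whether the flattened loop is actually safe to run in parallel
    // (no loop-carried dependence) still has to be decided separately.
    flattened->set_parallel();
  }
}
```

This way the nest carrying a loop-carried dependence can simply be left serial without blocking parallelization of the other nest.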
torch/csrc/jit/tensorexpr/kernel.cpp (Outdated)
      continue;
    }
    // Try to flatten the outer loops and parallelize them if successful.
    For* flattened = nullptr;
Nit: ForPtr please ;)
LGTM
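For context, a hedged sketch of how the flatten-and-parallelize step might read with the `ForPtr` spelling from the nit. The single-loop shortcut and the `set_parallel` call are assumptions about the surrounding code, not the exact diff:

```cpp
// Sketch only: try to flatten the outer loops leading to this buffer and
// mark the result as parallel; if flattening fails, give up on this buffer.
ForPtr flattened = nullptr;
if (loops.size() == 1) {
  flattened = loops[0];
} else {
  LoopNest::flatten(loops, &flattened);
}
if (flattened) {
  // Ask the LLVM codegen to emit this loop as a parallel dispatch.
  flattened->set_parallel();
}
```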
          callee(index, packed_data);
        }
      });
    } catch (...) {
Kinda curious about this place: why not terminate if there's an exception? I guess executing the remaining stmts would ultimately lead to wrong results?
Interesting point... if an exception happens here things are really screwed up, because we don't know how to unwind past llvm-generated frames. But no exceptions should be possible here, since we're just parallel-dispatching to our own kernel, which doesn't throw exceptions. So I was mainly putting the try-catch here to ensure that the compiler knew that it wouldn't need to unwind this frame.
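To make that concrete, here is a hedged sketch of the dispatch shim being discussed: the kernel-side callee never throws, and the `catch (...)` exists so that no exception can try to unwind through LLVM-generated frames. Function and type names are illustrative, not necessarily the PR's exact code:

```cpp
#include <ATen/Parallel.h>
#include <cstdint>

// The LLVM-compiled kernel exposes a plain C entry point that processes a
// single outer-loop index given a pointer to its packed arguments.
using ParallelCallee = void (*)(int64_t index, int8_t* packed_data);

// Sketch of the dispatch shim: fan the flattened outer loop out over the
// ATen thread pool. It is declared noexcept and swallows any exception so
// the compiler never has to unwind through LLVM-generated frames.
void DispatchParallel(
    int8_t* func,
    int64_t start,
    int64_t stop,
    int8_t* packed_data) noexcept {
  try {
    auto callee = reinterpret_cast<ParallelCallee>(func);
    at::parallel_for(start, stop, 1, [&](int64_t begin, int64_t end) {
      for (int64_t index = begin; index < end; index++) {
        callee(index, packed_data);
      }
    });
  } catch (...) {
    // The callee is our own generated kernel and does not throw; if
    // something does throw inside parallel_for, drop it here rather than
    // let it propagate into frames that cannot be unwound.
  }
}
```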
@bertmaher merged this pull request in d6d86ef.
This pull request has been reverted by 37d60c0.
Stack from ghstack:
Differential Revision: D30360382