[nnc] Support thread level parallelism in fused kernels #63386
Changes from all commits
25dbf98
ac73fc8
0670ad3
7e4cd4e
385846b
82d2274
```diff
@@ -274,15 +274,24 @@ class LLVMCodeGenImpl : public IRVisitor {
   }
 };
 
 extern "C" {
 typedef void (*ParallelCallee)(int index, int8_t* packed_data);
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
+void DispatchParallel(
+    int8_t* func,
+    int start,
+    int stop,
+    int8_t* packed_data) noexcept {
   // TODO: preserve the func type.
-  ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
-  at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
-    for (int index = f_begin; index < f_end; index++) {
-      callee(index, packed_data);
-    }
-  });
+  try {
+    ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
+    at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
+      for (int index = f_begin; index < f_end; index++) {
+        callee(index, packed_data);
+      }
+    });
+  } catch (...) {
```
Reviewer: Kinda curious about this place: why not terminate if there's an exception? I guess the execution of the remaining statements would ultimately lead to wrong results?

Author: Interesting point... if an exception happens here things are really screwed up, because we don't know how to unwind past LLVM-generated frames. But no exceptions should be possible here, since we're just parallel-dispatching to our own kernel, which doesn't throw exceptions. So I was mainly putting the try-catch here to ensure that the compiler knew that it wouldn't need to unwind this frame.
```diff
+  }
 }
 }
 
 } // namespace tensorexpr
```
```diff
@@ -1287,6 +1296,7 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) {
       module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
   llvm::Function* dispatcher =
       llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
+  dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
   irb_.CreateCall(
       dispatcher, {func_value, start, stop, packed_caller_args_ptr});
   value_ = llvm::ConstantInt::get(IntTy_, 0);
```
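Taken together, the two hunks make the same promise on both sides of the JIT boundary: `DispatchParallel` is declared `noexcept` and swallows any stray exception, and its LLVM declaration is marked `NoUnwind` so the generated code never needs unwind support for the call. The snippet below is a minimal standalone sketch of that pattern, not code from this PR; `dispatch` and `misbehaving_kernel` are made-up names. The point it illustrates: if an exception escaped the `noexcept` function, `std::terminate()` would fire, so the `catch (...)` is what actually keeps the no-unwind promise honest.

```cpp
#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for a JIT-compiled kernel body; in this sketch it
// throws, which real NNC kernels should never do.
void misbehaving_kernel(int /*index*/) { throw std::runtime_error("boom"); }

// Mirrors the DispatchParallel pattern: noexcept tells the compiler no
// exception escapes this frame, and catch (...) makes that guarantee hold
// even if a callee misbehaves.
void dispatch(void (*callee)(int), int start, int stop) noexcept {
  try {
    for (int i = start; i < stop; i++) {
      callee(i);
    }
  } catch (...) {
    // Swallow: we cannot unwind past frames (e.g. LLVM-generated ones)
    // that carry no unwind tables.
  }
}

int main() {
  dispatch(misbehaving_kernel, 0, 4);  // survives; without the catch, the
                                       // escaping exception would hit
                                       // std::terminate()
  std::puts("still running");
  return 0;
}
```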
Reviewer: Since this function is called after `fuseAllLoops`, it is possible that multiple buffers belong to the same loopnest. So we could be repeating this loop multiple times for the same loopnest. I understand that that may not be incorrect at this point, but it could lead to bugs in the future. IMO, we shouldn't be looking at output buffers and their loopnests. Instead, we should just take `root_stmt` in the given `LoopNest` and apply parallelization to all loopnests in that stmt. Wdyt?
Author: That seems OK to me, sure.
Author: I guess where things could get a little weird is if multiple buffers are updated at different levels of the loopnest, e.g. something like the sketch below. The current approach sort of gives each buffer an "independent" chance to affect the loop parallelization. That doesn't seem terrible, tbh.
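The comment's original inline example is not preserved in this extract; the following is a hypothetical illustration of the situation being described, with made-up names: `b0` is written at the outer loop level while `b1` is written inside the inner loop, so the set of loops reachable from each output buffer differs.

```cpp
// Hypothetical loopnest (not from the PR): b0's store sits under only the
// outer `i` loop, while b1's store sits under the full {i, j} nest. Asking
// "which loops lead to this buffer?" gives a different answer per buffer,
// so each buffer independently influences what gets flattened/parallelized.
void two_level_writes(const float* a, float* b0, float* b1, int M, int N) {
  for (int i = 0; i < M; i++) {
    b0[i] = a[i * N];                       // written at the outer level
    for (int j = 0; j < N; j++) {
      b1[i * N + j] = a[i * N + j] * 2.0f;  // written at the inner level
    }
  }
}
```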
Author: Actually, the more I think about it, is there really any advantage to starting with the root stmt and working down? From my POV it just makes the code a lot more complicated; the way things happen now, I just get a nice vector of loops leading to a buffer and try to flatten them. If it's not flattenable, it simply fails and I give up.

Although, maybe it's not too much work to build up my own vector starting from the root. Idk.
Reviewer: If you are okay with having the same set of loops being handled here for different bufs, then I have no objections to it.

Personally, I felt starting from the `root_stmt` might be better. We might need another API to extract all loops in the `root_stmt`, so maybe we can do this in the future.
Second reviewer: If I understand correctly, going through all buffers would not miss parallelism opportunities like the following (reconstructed sketch below): `i+j1` cannot be parallelized because there's a data dependence between iterations for `y1`, but `i+j2` can be parallelized and we should not miss it. If this is the thing we try to do here, I guess we need a `distribute` transformation before `flatten`; `flatten` currently only handles perfectly nested loops.
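The reviewer's inline example is not preserved in this extract; the following is a hypothetical reconstruction that matches the description. The names `y1`, `y2`, `i`, `j1`, `j2` come from the comment; everything else is made up.

```cpp
// Hypothetical reconstruction: y1 has a loop-carried dependence across j1,
// so the flattened i+j1 space cannot be parallelized as a whole, while the
// i+j2 space is fully parallel. Because the j1 and j2 loops are siblings,
// the nest is not perfectly nested; a `distribute` step would be needed
// before `flatten` could handle each piece separately.
void imperfect_nest(const float* x, float* y1, float* y2, int M, int N) {
  for (int i = 0; i < M; i++) {
    for (int j1 = 1; j1 < N; j1++) {
      y1[i * N + j1] = y1[i * N + j1 - 1] + x[i * N + j1];  // carried dependence
    }
    for (int j2 = 0; j2 < N; j2++) {
      y2[i * N + j2] = x[i * N + j2] * 2.0f;  // independent iterations
    }
  }
}
```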
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Author: Yeah, the approach here will definitely miss opportunities where some nested loops are parallelizable. It's really kind of a best-effort thing to get simple elementwise fusions right, not a general solution to parallelism.
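For context on what that best-effort path buys: for a perfectly nested elementwise loop, flattening the nest and then splitting the flat index range across threads is enough. Below is a minimal sketch of that shape, not the PR's generated code (`fused_elementwise` is a made-up name), using the same `at::parallel_for` entry point that `DispatchParallel` uses.

```cpp
#include <ATen/Parallel.h>

#include <cstdint>

// Minimal sketch of "flatten, then parallelize the flat loop" for a simple
// elementwise fusion: the {i, j} nest collapses to one index space, which
// at::parallel_for splits across threads.
void fused_elementwise(
    const float* a,
    const float* b,
    float* out,
    int64_t M,
    int64_t N) {
  at::parallel_for(0, M * N, 1, [&](int64_t begin, int64_t end) {
    for (int64_t flat = begin; flat < end; flat++) {
      const int64_t i = flat / N;  // recover the original 2-D indices
      const int64_t j = flat % N;
      out[i * N + j] = a[i * N + j] + b[i * N + j];
    }
  });
}
```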