nn.Linear: dispatch to bsr_dense_mm for half and bfloat16 #94825
```diff
@@ -4,6 +4,10 @@
 #include <ATen/native/sparse/SparseBlasImpl.h>
 #include <ATen/SparseCsrTensorUtils.h>
 
+// Required for checking whether Triton kernels are available
+#include <torch/csrc/jit/frontend/function_schema_parser.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
@@ -12,6 +16,7 @@
 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/_triton_bsr_dense_mm.h>
 #endif
 
 namespace at {
 
@@ -70,6 +75,31 @@ Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& s
     blocksize = {values.size(-2), values.size(-1)};
   }
 
+  // No stable support for ROCM in Triton yet.
+#ifndef USE_ROCM
+  // Triton works only with blocksizes which are powers of 2.
+  const auto is_power_of_2 = [](int64_t v) -> bool {
+    return !(v & (v - 1));
+  };
+
+  // Dtype and blocksize checks for potential Triton usage.
+  if ((strided.scalar_type() == ScalarType::Half
+       || strided.scalar_type() == ScalarType::BFloat16)
+      && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
+      && (blocksize[0] >= 16) && (blocksize[1] >= 16)
+      // lhs is retiled to (b0, b1) while rhs is to (b1, b0),
+      // so the result is tiled to (b0, b0) and we need to make
+      // sure that dense.size(-1) is divisible by b0.
+      && n % blocksize[0] == 0) {
+    const auto triton_kernel = c10::Dispatcher::singleton()
+        .findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
+
+    // Call Triton only if dispatch key was overwritten.
+    if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
+      return at::_triton_bsr_dense_mm_out(result, compressed, strided);
+    }
+  }
+#endif
+
   // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
   // NOTE: this function ALWAYS creates a view upon successful execution.
   const auto tile_tensor = [compressed_layout](
```

Inline review thread on the `// No stable support for ROCM in Triton yet.` comment (marked resolved):

Reviewer: Isn't there? We have a Triton ROCm binary as a dependency for the install.

nikitaved: Something might have changed since then, I will try it out.

Reviewer: ok!
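An aside on the bit trick in the diff above: a positive power of two has exactly one set bit, so `v & (v - 1)` clears that bit and yields zero. Note that the lambda also returns true for `v == 0`, which is harmless here because the `blocksize >= 16` checks follow. A minimal, self-contained sketch of the predicate:

```cpp
#include <cassert>
#include <cstdint>

// Same predicate as in the diff: true iff v has at most one set bit.
bool is_power_of_2(int64_t v) {
  return !(v & (v - 1));
}

int main() {
  assert(is_power_of_2(16));   // 0b10000 & 0b01111 == 0
  assert(!is_power_of_2(24));  // 0b11000 & 0b10111 == 0b10000 != 0
  assert(is_power_of_2(0));    // edge case: 0 passes too, ruled out by the size checks
  return 0;
}
```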
Review thread on `_triton_bsr_dense_mm` being defined as a native function:

Reviewer: Why do you need this to be a native function? You can define it directly in Python if you never need it here.

nikitaved: Well, we actually need it, as it is being called inside `addmm`.

Reviewer: If you call it as `at::_triton_bsr_dense_mm`, you do indeed. But since you already have the op, you could also call it via `triton_kernel.call(...)`, and that doesn't require a native function. Being a native function just gives you all the codegen niceties (namespaces, an `out=` variant, etc.). But if you don't need any of these, you most likely don't want it to be native, as it adds more constraints (BC guarantees in particular). This is not a big deal, I just wanted to mention it in case you actually want the flexibility.
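For illustration, a hedged sketch of that alternative: calling the registered op through a typed dispatcher handle rather than through the codegen'd `at::` entry point. The functional (no `out=`) signature here is an assumption made for the example:

```cpp
#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical call site: look the op up once, then dispatch through the
// handle. No native_functions.yaml entry (and no BC guarantee) is required.
at::Tensor call_triton_bsr_dense_mm(const at::Tensor& bsr, const at::Tensor& dense) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::_triton_bsr_dense_mm", "")
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&)>();
  return op.call(bsr, dense);
}
```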
nikitaved: Sure, I am aware of the limitations and of all the nice things (actually using `out` here). But, unfortunately, `linear` itself is a native function, and so is `addmm`. We could indeed wrap these in Python and call into the C++ implementations if needed, but that, if simple, could be done in a follow-up...

Reviewer: I'm not sure I follow your comment. How is it linked to the `linear` and `addmm` implementations?

nikitaved: Thank you, @bdhirsh, @albanD, I will give it a shot.

nikitaved: @bdhirsh, whether the schema includes an `out` or not, I get a failure. The issue is non-deterministic and only manifests when several tests are run in a group, say with `pytest -sv test/test_sparse_csr.py -k triton`; I was not able to fail any individual test. EDIT: apparently, some CIs even get segfaults...

bdhirsh: That seems... very strange (although the error seems more related to Triton than to the dispatcher?). @nikitaved, what if you try making the dispatcher op itself not an `out=` op, and implement it by allocating a tensor manually inside the op? That error looks pretty cryptic; I see a similar error from a Google search in JAX, around "passing a static arg by keyword".
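A minimal sketch of that suggestion, under the assumption that the op keeps its two-tensor schema: register a functional (non-`out=`) kernel that allocates its own result. The shape math and the registration shown here are illustrative, not the PR's code:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical functional variant: no out= argument; the result tensor is
// allocated inside the op itself.
at::Tensor triton_bsr_dense_mm(const at::Tensor& bsr, const at::Tensor& dense) {
  auto result = at::empty({bsr.size(-2), dense.size(-1)}, dense.options());
  // ... launch the Triton kernel, writing into `result` ...
  return result;
}

// Register the kernel for the sparse-CSR CUDA dispatch key.
TORCH_LIBRARY_IMPL(aten, SparseCsrCUDA, m) {
  m.impl("_triton_bsr_dense_mm", triton_bsr_dense_mm);
}
```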
nikitaved: @bdhirsh, unfortunately, same issue. I did not have such issues with the native function, however.

nikitaved: Shall we investigate that further, or do that in a follow-up?