Integrating cuBLASLt into XLA #55518
Conversation
```cpp
      BlasPlansCompatibleType(element_type)) {
    TF_RETURN_IF_ERROR(
        DoBlasPlansAutotune(stream, instr, allocator, gemm_config));
    return {se::blas::kNoAlgorithm};
```
What is the significance of returning `kNoAlgorithm` here?
The cuBLASLt autotuner operates on `se::blas::AlgorithmConfig` instead of the `se::blas::AlgorithmType` used in non-cuBLASLt autotuning. As such, the result of cuBLASLt autotuning is incompatible with the return value of `DoGemmAutotune`, and the `se::blas::kNoAlgorithm` dummy value is returned instead. The outcome of cuBLASLt autotuning is stored in the instance of `BlasPlansAutotuneCacheSingleton`.
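Below is a minimal, self-contained C++ sketch of the pattern described above, not the actual XLA code: `AlgorithmType`, `AlgorithmConfig`, `PlanCache`, `Autotune`, and the cache key are hypothetical stand-ins for `se::blas::AlgorithmType`, `se::blas::AlgorithmConfig`, `BlasPlansAutotuneCacheSingleton`, and `DoBlasPlansAutotune`. It illustrates why a sentinel like `kNoAlgorithm` is returned when the real result lives in a side cache.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-ins for the se::blas types named above.
using AlgorithmType = int;                     // plain algorithm id
constexpr AlgorithmType kNoAlgorithm = -1;     // sentinel: "no id returned"
struct AlgorithmConfig { std::string plan; };  // richer cuBLASLt-style result

// Hypothetical stand-in for BlasPlansAutotuneCacheSingleton.
std::map<std::string, AlgorithmConfig>& PlanCache() {
  static std::map<std::string, AlgorithmConfig> cache;
  return cache;
}

// Sketch of an autotune entry point whose return type can only carry an
// AlgorithmType. The cuBLASLt-style result is an AlgorithmConfig, so it is
// stored in the side cache and the sentinel is returned instead.
AlgorithmType Autotune(const std::string& key, bool use_cublaslt) {
  if (use_cublaslt) {
    PlanCache()[key] = AlgorithmConfig{"best plan for " + key};
    return kNoAlgorithm;
  }
  return 7;  // legacy path: the tuned algorithm id itself
}

int main() {
  const std::string key = "gemm_f32_128x128";
  if (Autotune(key, /*use_cublaslt=*/true) == kNoAlgorithm) {
    std::cout << "cached: " << PlanCache().at(key).plan << "\n";
  }
}
```

The sentinel keeps the legacy return type intact while the richer cuBLASLt result travels out of band through the cache.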
```cpp
  GemmCacheKey key =
      std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
                      instr->shape(), gemm_config.SerializeAsString());
  if (stream->parent()->SupportsBlasPlans() && config.use_cublaslt &&
```
Maybe slightly pedantic, but we ought to check the flag first, lest `SupportsBlasPlans` have some kind of side effect.
`SupportsBlasPlans()` is introduced in this PR and has no side effects. It returns true if the CUDA version is greater than or equal to 11, and false otherwise.
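A hedged sketch of what the reply describes, assuming `SupportsBlasPlans()` is a pure predicate on the CUDA major version (the constant below is an assumption for illustration; the real code would query the runtime). It also shows the reviewer's suggested operand order, which short-circuits on the flag before calling the predicate:

```cpp
#include <iostream>

// Assumed fixed here for illustration; queried at runtime in reality.
constexpr int kCudaMajorVersion = 11;

// Pure predicate, no side effects: true iff the CUDA version is >= 11.
bool SupportsBlasPlans() { return kCudaMajorVersion >= 11; }

int main() {
  const bool use_cublaslt = false;  // e.g. from the xla_gpu_enable_cublaslt flag
  // With the flag first, SupportsBlasPlans() is never evaluated when the
  // feature is off. Since the predicate is pure, both operand orders behave
  // identically, so the reordering is purely defensive style.
  if (use_cublaslt && SupportsBlasPlans()) {
    std::cout << "cuBLASLt GEMM path\n";
  } else {
    std::cout << "legacy cuBLAS GEMM path\n";
  }
}
```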
@philipphack Can you please resolve conflicts? Thank you!
PiperOrigin-RevId: 441800149
FYI @philipphack, this was merged. I'm not sure why this PR hasn't been updated.
It seems auto-merge is not happening, but the changes are now merged into master, so we can close this. Thank you for the PR.
Adds support for the cuBLASLt library for GEMM operations to XLA. The library can be activated by setting the XLA flag `xla_gpu_enable_cublaslt=true`.
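(Usage note: XLA flags of this form are typically passed via the `XLA_FLAGS` environment variable, e.g. `XLA_FLAGS=--xla_gpu_enable_cublaslt=true`; the exact plumbing may vary by TensorFlow version.)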
@SandSnip3r can you run the test before merging?