[XLA] Cublaslt fp8 matmul restriction work-around #59515
Conversation
Force-pushed from ea63fee to 706ad9a
@@ -646,7 +642,40 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       // and B are later exchanged, and B is transposed here instead.
       // TODO(philipphack): Remove once cuBLASLt supports the NN configuration.
       TF_ASSIGN_OR_RETURN(bool is_col_major,
-                          OutputIsColumnMajor(instr, gemm_backend_config));
+                          MatrixIsColumnMajor(instr, gemm_backend_config, "d"));
`MatrixIsColumnMajor` doesn't exist. Did you forget to commit that change?
Yes. Just added.
  // Bitcast the operands to realign their logical and physical dimensions.
  std::vector<int64_t> a_dim_order;
  a_dim_order.reserve(a_dims.size());
  absl::Span<const int64_t> a_minor_to_major =
      a->shape().layout().minor_to_major();
  for (int i = 0; i < a_dims.size(); ++i) {
    a_dim_order.emplace_back(
        absl::c_find(a_minor_to_major,
                     is_col_major ? i : a_dims.size() - i - 1) -
        a_minor_to_major.begin());
  }
  a = instr->AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::PermuteDimensions(a_dim_order, a->shape()), a, a_dim_order));

  std::vector<int64_t> b_dim_order;
  b_dim_order.reserve(b_dims.size());
  absl::Span<const int64_t> b_minor_to_major =
      b->shape().layout().minor_to_major();
  for (int i = 0; i < b_dims.size(); ++i) {
    b_dim_order.emplace_back(
        absl::c_find(b_minor_to_major,
                     is_col_major ? i : b_dims.size() - i - 1) -
        b_minor_to_major.begin());
  }
  b = instr->AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::PermuteDimensions(b_dim_order, b->shape()), b, b_dim_order));
I don't quite understand this code block. Can you clarify what it means to "realign their logical and physical dimensions"?
Consider the following case: {128,32,64}[2,0,1], assuming the output is row major. This code block turns it into {128,64,32}[2,1,0]. Essentially, it aligns the physical dimensions (the minor-to-major order in square brackets) to [N-1,N-2,...,0], and permutes the logical dimensions (in curly brackets) accordingly.
  // Apply the necessary transposes to canonicalize the matmul (lhs and rhs
  // contracting dims are 1 and 0). Also assumes a later transpose-folding
  // pass will remove duplicated transposes. The last transpose is required
  // by the cuBLAS FP8 matmul restriction.
Please confirm my understanding of the issue. IIUC, right now (without this PR), gemm_rewriter transposes the LHS of the matmul. This is to negate the transpose that MatmulPlan does to meet the cuBLAS requirement that the LHS must be transposed.

The issue with this approach (which this PR addresses) is that this process doesn't take into account the layouts of the LHS and the RHS (although it does check the layout of the output). According to this GitHub comment:

> The `cublasLtOrder_t` is provided for convenience. Actual matrix layout is figured out internally by also taking `cublasOperation_t` into account. So you can construct multiple valid combinations that all will work as long as they map to column-major TN.

The code currently assumes that the LHS must be transposed and the RHS must not be, but in reality you must also take the order (row-major vs. column-major) into account.
I'm still trying to understand this PR, but perhaps an alternative way to fix this would be the following: remove the transpose in MatmulPlan, then ensure the LHS is "effectively" row-major and the RHS is "effectively" column-major ("effectively" here meaning what `MatrixIsColumnMajor` returns, as opposed to what the layout indicates, which can differ depending on the contracting dimensions). You can do this by using `MatrixIsColumnMajor` to determine the order of the LHS and RHS, transposing the LHS if it is not row-major and transposing the RHS if it is not column-major. This would avoid the large number of if-statements below.
Also, can you add a test to tensorflow/compiler/xla/tests/dot_operation_test.cc, similar to the existing FP8 tests?
See #59515 (comment)
/CC @philipphack
@reedwm This is largely independent of the requirement of transposing A but not B. The column-major/row-major layout returned by GemmConfig::For doesn't fully describe the configuration of a GEMM. Not considering batch dimensions, A and B each can have two possible contracting dimensions (0, 1) and two possible minor-to-major orders ({0,1}, {1,0}). The idea is to bitcast/transpose the operands into the configuration supported by cuBLASLt (A and B have contracting dimensions 1 and 0, and both have {1,0} orders). Since this only affects FP8 GEMMs, I think it can make sense to introduce this here without changing the general logic.
-                      OutputIsColumnMajor(instr, gemm_backend_config));
+                      MatrixIsColumnMajor(instr, gemm_backend_config, "d"));
TF_ASSIGN_OR_RETURN(bool a_is_col_major,
Where is this used?
Fixed.
  TF_ASSIGN_OR_RETURN(bool a_is_col_major,
                      MatrixIsColumnMajor(instr, gemm_backend_config, "a"));
TF_ASSIGN_OR_RETURN(bool b_is_col_major,
Where is this used?
a_is_col_major and b_is_col_major are needed to determine which transposes should be applied.
  // Plain transpose on a or b. A plain transpose permutes a matrix's
  // dimensions without changing its storage order.
  auto plain_transpose = [&](const char matrix_name) {
Calling this with the operand and its dimensions instead of the name would make the if statement redundant.
Fixed.
Force-pushed from b200202 to e0e67a4
@philipphack @wenscarl @reedwm what is the next step here?
Is that still the plan?
I confirmed the non-TN invocation works (the NN invocation, where neither input is transposed). I'm working on modifying this PR to run with NN, then I'll share it and ask if you think it's clearer. @philipphack @wenscarl do you plan on adding more test cases? If not, I can also add some.
Eventually, can we please streamline and shorten the tests, perhaps by parametrizing them? This would also help with readability. https://google.github.io/googletest/reference/testing.html#TEST_P
Force-pushed from 4bd7703 to adba8d7
Imported from GitHub PR tensorflow/tensorflow#59515

As of 2/1/2023, the cublasLt FP8 matmul only supports column-major input (the cuBLAS default), but when called from TF/XLA the inputs can have any storage layout. This PR "canonicalizes" FP8 matmuls by setting lhs/rhs_contracting_dims to {1,0} and adding the necessary transposes to the inputs. A reproducer of this restriction is located [here](https://github.com/wenscarl/fp8_gemm_test/blob/main/fp8_gemm_backward_fail.py). A remaining restriction is that the batch dimension still needs to be a leading dimension.

Copybara import of the project:

-- 77de65a4d0c2c0c7db8e6b305d03e039c454c2e2 by wenscarl <shuw@nvidia.com>: Workaround cublasLt fp8 matmul restrictions.
-- 706ad9adb3e5109acc7909d39fa89c9286c192f3 by shuw <shuw@nvidia.com>: Work around cublasLt fp8 matmul restrictions
-- 8f4856968d826b621a3a083f7843ffb30b8d921f by wenscarl <shuw@nvidia.com>: Add MatrixIsColumnMajor
-- e0e67a4223dda278463e2004f975b32c0d5ac3d8 by shuw <shuw@nvidia.com>: Remove dead code
-- cfea2cb8dc60edb07bdf1fd3c2113cd401354701 by shuw <shuw@nvidia.com>: Add gemm_rewrite test
-- 460f616103888ad6d80dc22cc3876690ab7e16c3 by shuw <shuw@nvidia.com>: Cover 32 cases
-- 6f724ff30570d7f35dba03a13d6a9bf6aa903da5 by shuw <shuw@nvidia.com>: Abbreviate branches logics
-- 64f153a32cb8ddd925f6f8d4a0ae88178990f24c by shuw <shuw@nvidia.com>: Parameterized test
-- 6146558b21416ec3c5b8aa32f7f59c9f3b49ae87 by shuw <shuw@nvidia.com>: auto -> type names
-- 96e758886d67fe868e4d263b44cc729af972fb71 by shuw <shuw@nvidia.com>: Update comments
-- 07db882429e7e93afba335714d672c460483ec4f by shuw <shuw@nvidia.com>: vector -> array
-- 3089065fb4d78033b85f46b68cd25a0643dd17e3 by shuw <shuw@nvidia.com>: fix typo
-- 50eaec634ccf7d77a0e8315ade5e0a5bab06c07e by shuw <shuw@nvidia.com>: Convert to NN config and add batched matmul tests

Merging this change closes #59515

FUTURE_COPYBARA_INTEGRATE_REVIEW=tensorflow/tensorflow#59515 from wenscarl:cublaslt_fp8_matmul_war bff4bc682d5e0bc200e6261317ba1499960c3e45
PiperOrigin-RevId: 512191684
For a[b]_bitcast, MakeShape -> MakeShapeWithDenseLayout.
Imported from GitHub PR tensorflow/tensorflow#59515

Copybara import of the project:

-- fe390a9c6c94b25d0e78ebced21f8d74bb2de522 by shuw <shuw@nvidia.com>: Update gemm_rewrite tests
-- de695a65cc10678afaa8a486443055f03db50f15 by wenscarl <shuw@nvidia.com>: Preserve physical layout. For a[b]_bitcast, MakeShape -> MakeShapeWithDenseLayout.

Merging this change closes #59515

PiperOrigin-RevId: 512599386
As of 2/1/2023, the cublasLt FP8 matmul only supports column-major input (the cuBLAS default), but when called from TF/XLA the inputs can have any storage layout. This PR "canonicalizes" FP8 matmuls by setting lhs/rhs_contracting_dims to {1,0} and adding the necessary transposes to the inputs. A reproducer of this restriction is located here.
A remaining restriction is that the batch dimension still needs to be a leading dimension.