
[XLA] Cublaslt fp8 matmul restriction work-around #59515

Merged
merged 17 commits into tensorflow:master from wenscarl:cublaslt_fp8_matmul_war on Feb 27, 2023

Conversation

wenscarl
Contributor

@wenscarl wenscarl commented Feb 1, 2023

As of 2/1/2023, cuBLASLt FP8 matmul only supports column-major inputs (the cuBLAS default). But when called from TF/XLA, the inputs can arrive in any storage layout. This PR "canonicalizes" FP8 matmuls by setting lhs/rhs_contracting_dims to {1,0} and adding the necessary transposes to the inputs. A reproducer of this restriction is located here.

A remaining restriction is that the batch dimension still needs to be the leading dimension.

@google-ml-butler google-ml-butler bot added the size:M CL Change Size: Medium label Feb 1, 2023
@google-ml-butler google-ml-butler bot requested a review from r4nt February 1, 2023 21:23
@google-ml-butler google-ml-butler bot added the awaiting review Pull request awaiting review label Feb 1, 2023
@wenscarl wenscarl changed the title from "Cublaslt fp8 matmul war" to "Cublaslt fp8 matmul restriction work-around" on Feb 1, 2023
@reedwm reedwm requested review from reedwm and removed request for r4nt February 1, 2023 22:03
@@ -646,7 +642,40 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
// and B are later exchanged, and B is transposed here instead.
// TODO(philipphack): Remove once cuBLASLt supports the NN configuration.
TF_ASSIGN_OR_RETURN(bool is_col_major,
-                    OutputIsColumnMajor(instr, gemm_backend_config));
+                    MatrixIsColumnMajor(instr, gemm_backend_config, "d"));
Member

MatrixIsColumnMajor doesn't exist. Did you forget to commit that change?

Contributor Author

Yes. Just added.

Comment on lines 653 to 678
  // Bitcast the operands to realign their logical and physical dimensions.
  std::vector<int64_t> a_dim_order;
  a_dim_order.reserve(a_dims.size());
  absl::Span<const int64_t> a_minor_to_major =
      a->shape().layout().minor_to_major();
  for (int i = 0; i < a_dims.size(); ++i) {
    a_dim_order.emplace_back(
        absl::c_find(a_minor_to_major,
                     is_col_major ? i : a_dims.size() - i - 1) -
        a_minor_to_major.begin());
  }
  a = instr->AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::PermuteDimensions(a_dim_order, a->shape()), a, a_dim_order));

  std::vector<int64_t> b_dim_order;
  b_dim_order.reserve(b_dims.size());
  absl::Span<const int64_t> b_minor_to_major =
      b->shape().layout().minor_to_major();
  for (int i = 0; i < b_dims.size(); ++i) {
    b_dim_order.emplace_back(
        absl::c_find(b_minor_to_major,
                     is_col_major ? i : b_dims.size() - i - 1) -
        b_minor_to_major.begin());
  }
  b = instr->AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::PermuteDimensions(b_dim_order, b->shape()), b, b_dim_order));
Member

I don't quite understand this code block. Can you clarify what it means to "realign their logical and physical dimensions"?

Contributor Author

Consider the following case: {128,32,64}[2,0,1], and assume the output is row-major. This code block turns it into {128,64,32}[2,1,0]. Essentially, it aligns the physical dimensions (the ones in square brackets) to [N-1, N-2, ..., 0] and permutes the logical dimensions in curly braces to match.
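
For illustration, here is a standalone sketch (plain C++, not XLA code) that runs the same find-based loop on this example and prints the resulting permutation and shape:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example from the comment above: logical dims {128,32,64}, physical
  // (minor-to-major) order [2,0,1], row-major output => is_col_major = false.
  std::vector<int64_t> dims = {128, 32, 64};
  std::vector<int64_t> minor_to_major = {2, 0, 1};
  const bool is_col_major = false;

  std::vector<int64_t> dim_order;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    int64_t wanted =
        is_col_major ? i : static_cast<int64_t>(dims.size()) - 1 - i;
    dim_order.push_back(
        std::find(minor_to_major.begin(), minor_to_major.end(), wanted) -
        minor_to_major.begin());
  }
  // dim_order is {0, 2, 1}; transposing with it yields logical dims
  // {128, 64, 32}, whose physical order is the row-major [2, 1, 0].
  for (int64_t d : dim_order) std::cout << dims[d] << ' ';  // prints: 128 64 32
  std::cout << '\n';
}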

Comment on lines 740 to 743
// Apply the transposes needed to canonicalize the matmul (lhs and rhs
// contracting dims become 1 and 0). This assumes a later transpose-folding
// pass will remove duplicated transposes. The last transpose is required by
// the cuBLAS FP8 matmul restriction.
Member

Please confirm my understanding of the issue. IIUC, right now (without this PR), gemm_rewriter transposes the LHS of the matmul. This is to negate the transpose that MatmulPlan does to meet the cuBLAS requirement that the LHS must be transposed.

The issue with this approach (which this PR addresses) is that this process doesn't take the layouts of the LHS and the RHS into account (although it does check the layout of the output). According to this GitHub comment:

The cublasLtOrder_t is provided for convenience. Actual matrix layout is figured out internally by also taking cublasOperation_t into account. So you can construct multiple valid combinations that all will work as long as they map to column-major TN.

The code currently assumes that the LHS must be transposed and the RHS must not be, but in reality you must also take the order (row-major vs. column-major) into account.

I'm still trying to understand this PR, but perhaps an alternative way to fix this would be the following: remove the transpose in MatmulPlan, then ensure the LHS is "effectively" row-major and the RHS is "effectively" column-major ("effectively" meaning what MatrixIsColumnMajor returns, as opposed to what the layout indicates, which can differ depending on the contracting dimensions). You can do this by using MatrixIsColumnMajor to determine the order of the LHS and RHS, transposing the LHS if it is not row-major and transposing the RHS if it is not column-major. This would avoid the large number of if-statements below.
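
A rough sketch of that alternative flow, for discussion only (TransposeMatrix here is a hypothetical helper standing in for "swap the two non-batch dimensions of the operand"; only MatrixIsColumnMajor comes from this PR):

  TF_ASSIGN_OR_RETURN(bool a_is_col_major,
                      MatrixIsColumnMajor(instr, gemm_backend_config, "a"));
  TF_ASSIGN_OR_RETURN(bool b_is_col_major,
                      MatrixIsColumnMajor(instr, gemm_backend_config, "b"));
  // With the transpose removed from MatmulPlan, cuBLASLt effectively needs an
  // A that is row-major and a B that is column-major.
  if (a_is_col_major) {
    a = TransposeMatrix(a);  // hypothetical helper
  }
  if (!b_is_col_major) {
    b = TransposeMatrix(b);  // hypothetical helper
  }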

Also can you add a test to tensorflow/compiler/xla/tests/dot_operation_test.cc, similar to the existing FP8 tests?

Contributor

@reedwm
Member

reedwm commented Feb 2, 2023

/CC @philipphack

@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Feb 2, 2023
@gbaned gbaned added the comp:xla XLA label Feb 2, 2023
@philipphack
Contributor

@reedwm This is largely independent of the requirement of transposing A but not B. The column-major/row-major layout returned by GemmConfig::For doesn’t fully describe the configuration of a GEMM. Not considering batch dimensions, A and B each can have two possible contracting dimensions (0, 1) and two possible minor-to-major orders ({0,1}, {1,0}).

The idea is to bitcast/transpose the operands into the configuration supported by cuBLASLt (A and B have contracting dimensions 1 and 0 and both have {1, 0} orders). Since this only affects FP8 GEMMs, I think it can make sense to introduce this here without changing the general logic.

-                    OutputIsColumnMajor(instr, gemm_backend_config));
+                    MatrixIsColumnMajor(instr, gemm_backend_config, "d"));

TF_ASSIGN_OR_RETURN(bool a_is_col_major,
Contributor

Where is this used?

Contributor Author

Fixed.

TF_ASSIGN_OR_RETURN(bool a_is_col_major,
MatrixIsColumnMajor(instr, gemm_backend_config, "a"));

TF_ASSIGN_OR_RETURN(bool b_is_col_major,
Contributor

Where is this used?

Contributor Author

These a_is_col_major/b_is_col_major values are needed to determine which transposes should be applied.


// Plain transpose on a or b. A plain transpose permutes a matrix's
// dimensions without changing its storage order.
auto plain_transpose = [&](const char matrix_name) {
Contributor

Calling this with the operand and its dimensions instead of the name would make the if statement redundant.
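
For reference, one shape the refactor could take (a sketch modeled on the realignment loop quoted earlier, not necessarily the committed version): the lambda receives the operand pointer directly, so both call sites share one code path and the name-based if-statement disappears.

  auto plain_transpose = [&](HloInstruction **x) {
    const Shape &shape = (*x)->shape();
    absl::Span<const int64_t> minor_to_major = shape.layout().minor_to_major();
    const int64_t rank = shape.rank();
    std::vector<int64_t> dim_order;
    dim_order.reserve(rank);
    for (int64_t i = 0; i < rank; ++i) {
      dim_order.push_back(absl::c_find(minor_to_major,
                                       is_col_major ? i : rank - 1 - i) -
                          minor_to_major.begin());
    }
    *x = instr->AddInstruction(HloInstruction::CreateTranspose(
        ShapeUtil::PermuteDimensions(dim_order, shape), *x, dim_order));
  };
  plain_transpose(&a);
  plain_transpose(&b);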

Contributor Author

Fixed.

@pjannaty
Contributor

(Quoting @philipphack's reply to @reedwm above.)

@philipphack @wenscarl @reedwm what is the next step here?

  • reedwm was going to give the non-TN invocation of cuBLASLt a try
  • wenscarl/philipphack were to look into reedwm's suggestion about moving some of the transposing logic to matmul_util

Is that still the plan?

@reedwm
Member

reedwm commented Feb 10, 2023

I confirmed the non-TN invocation works (the NN invocation, where neither input is transposed).

I'm working on modifying this PR to run with NN, then I'll share it and ask if you think it's clearer.

@philipphack @wenscarl do you plan on adding more test cases? If not I can also add some.

@pjannaty
Contributor

Eventually can we please streamline and shorten the tests by perhaps parametrizing them? This would also help with readability. https://google.github.io/googletest/reference/testing.html#TEST_P
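
A minimal sketch of that TEST_P pattern (class, test, and parameter names here are hypothetical, not the ones in dot_operation_test.cc):

#include <tuple>

#include <gtest/gtest.h>

// Parameter: (a_col_major, b_col_major, d_col_major) layout combination.
class Fp8GemmLayoutTest
    : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {};

TEST_P(Fp8GemmLayoutTest, RewritesToCanonicalForm) {
  const auto [a_col_major, b_col_major, d_col_major] = GetParam();
  // Build and run the FP8 matmul for this layout combination, then check that
  // the rewritten HLO calls cuBLASLt in the canonical configuration.
  (void)a_col_major;
  (void)b_col_major;
  (void)d_col_major;
}

// One instantiation covers all eight layout combinations.
INSTANTIATE_TEST_SUITE_P(Fp8GemmLayouts, Fp8GemmLayoutTest,
                         ::testing::Combine(::testing::Bool(),
                                            ::testing::Bool(),
                                            ::testing::Bool()));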

@wenscarl wenscarl force-pushed the cublaslt_fp8_matmul_war branch 2 times, most recently from 4bd7703 to adba8d7 on February 16, 2023 21:05
@google-ml-butler google-ml-butler bot added the ready to pull PR ready for merge process label Feb 24, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Feb 24, 2023
@google-ml-butler google-ml-butler bot added the kokoro:force-run Tests on submitted change label Feb 24, 2023
@kokoro-team kokoro-team removed kokoro:force-run Tests on submitted change labels Feb 24, 2023
PR Queue automation moved this from Approved by Reviewer to Reviewer Requested Changes Feb 24, 2023
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 25, 2023
Imported from GitHub PR tensorflow/tensorflow#59515

As of 2/1/2023, cuBLASLt FP8 matmul only supports column-major inputs (the cuBLAS default). But when called from TF/XLA, the inputs can arrive in any storage layout. This PR "canonicalizes" FP8 matmuls by setting lhs/rhs_contracting_dims to {1,0} and adding the necessary transposes to the inputs. A reproducer of this restriction is located [here](https://github.com/wenscarl/fp8_gemm_test/blob/main/fp8_gemm_backward_fail.py).

A remaining restriction is that the batch dimension still needs to be the leading dimension.
Copybara import of the project:

--
77de65a4d0c2c0c7db8e6b305d03e039c454c2e2 by wenscarl <shuw@nvidia.com>:

Workaround cublasLt fp8 matmul restrictions.
--
706ad9adb3e5109acc7909d39fa89c9286c192f3 by shuw <shuw@nvidia.com>:

Work around cublasLt fp8 matmul restrictions

--
8f4856968d826b621a3a083f7843ffb30b8d921f by wenscarl <shuw@nvidia.com>:

Add MatrixIsColumnMajor
--
e0e67a4223dda278463e2004f975b32c0d5ac3d8 by shuw <shuw@nvidia.com>:

Remove dead code

--
cfea2cb8dc60edb07bdf1fd3c2113cd401354701 by shuw <shuw@nvidia.com>:

Add gemm_rewrite test

--
460f616103888ad6d80dc22cc3876690ab7e16c3 by shuw <shuw@nvidia.com>:

Cover 32 cases

--
6f724ff30570d7f35dba03a13d6a9bf6aa903da5 by shuw <shuw@nvidia.com>:

Abbreviate branches logics

--
64f153a32cb8ddd925f6f8d4a0ae88178990f24c by shuw <shuw@nvidia.com>:

Parameterized test

--
6146558b21416ec3c5b8aa32f7f59c9f3b49ae87 by shuw <shuw@nvidia.com>:

auto -> type names

--
96e758886d67fe868e4d263b44cc729af972fb71 by shuw <shuw@nvidia.com>:

Update comments

--
07db882429e7e93afba335714d672c460483ec4f by shuw <shuw@nvidia.com>:

vector -> array

--
3089065fb4d78033b85f46b68cd25a0643dd17e3 by shuw <shuw@nvidia.com>:

fix typo

--
50eaec634ccf7d77a0e8315ade5e0a5bab06c07e by shuw <shuw@nvidia.com>:

Convert to NN config and add batched matmul tests

Merging this change closes #59515

FUTURE_COPYBARA_INTEGRATE_REVIEW=tensorflow/tensorflow#59515 from wenscarl:cublaslt_fp8_matmul_war bff4bc682d5e0bc200e6261317ba1499960c3e45
PiperOrigin-RevId: 512191684
@pjannaty pjannaty changed the title from "Cublaslt fp8 matmul restriction work-around" to "[XLA] Cublaslt fp8 matmul restriction work-around" on Feb 25, 2023
@google-ml-butler google-ml-butler bot removed the ready to pull PR ready for merge process label Feb 25, 2023
For a[b]_bitcast, MakeShape -> MakeShapeWithDenseLayout.
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 27, 2023
Imported from GitHub PR tensorflow/tensorflow#59515 (same description and commit list as the Feb 25 copybara import above).

Merging this change closes #59515

PiperOrigin-RevId: 512573287
copybara-service bot pushed a commit that referenced this pull request Feb 27, 2023
PR Queue automation moved this from Reviewer Requested Changes to Approved by Reviewer Feb 27, 2023
@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Feb 27, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Feb 27, 2023
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 27, 2023
Imported from GitHub PR tensorflow/tensorflow#59515

As of 2/1/2023, cuBLASLt FP8 matmul only supports column-major inputs (the cuBLAS default). But when called from TF/XLA, the inputs can arrive in any storage layout. This PR "canonicalizes" FP8 matmuls by setting lhs/rhs_contracting_dims to {1,0} and adding the necessary transposes to the inputs. A reproducer of this restriction is located [here](https://github.com/wenscarl/fp8_gemm_test/blob/main/fp8_gemm_backward_fail.py).

A remaining restriction is that the batch dimension still needs to be the leading dimension.
Copybara import of the project:

--
fe390a9c6c94b25d0e78ebced21f8d74bb2de522 by shuw <shuw@nvidia.com>:

Update gemm_rewrite tests

--
de695a65cc10678afaa8a486443055f03db50f15 by wenscarl <shuw@nvidia.com>:

Preserve physical layout.

For a[b]_bitcast, MakeShape -> MakeShapeWithDenseLayout.

Merging this change closes #59515

PiperOrigin-RevId: 512599386
@copybara-service copybara-service bot merged commit 0e7e616 into tensorflow:master Feb 27, 2023
PR Queue automation moved this from Approved by Reviewer to Merged Feb 27, 2023
Labels
awaiting review (Pull request awaiting review), comp:xla (XLA), ready to pull (PR ready for merge process), size:M (CL Change Size: Medium)
Projects
PR Queue (Merged)
Development

Successfully merging this pull request may close these issues.

None yet

7 participants