New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[XLA] Cublaslt fp8 matmul restriction work-around #59515
Changes from 12 commits
77de65a
0cc0fef
706ad9a
8f48569
e0e67a4
cfea2cb
460f616
6f724ff
64f153a
6146558
96e7588
07db882
3089065
50eaec6
bff4bc6
fe390a9
de695a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -605,14 +605,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { | |
return false; | ||
} | ||
if ((a_bitcast ? a_bitcast : a)->shape().dimensions_size() - | ||
gemm_backend_config.dot_dimension_numbers() | ||
.lhs_batch_dimensions() | ||
.size() != | ||
a_batch_dims.size() != | ||
2 || | ||
(b_bitcast ? b_bitcast : b)->shape().dimensions_size() - | ||
gemm_backend_config.dot_dimension_numbers() | ||
.rhs_batch_dimensions() | ||
.size() != | ||
b_batch_dims.size() != | ||
2) { | ||
return false; | ||
} | ||
|
@@ -646,7 +642,37 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { | |
// and B are later exchanged, and B is transposed here instead. | ||
// TODO(philipphack): Remove once cuBLASLt supports the NN configuration. | ||
TF_ASSIGN_OR_RETURN(bool is_col_major, | ||
OutputIsColumnMajor(instr, gemm_backend_config)); | ||
MatrixIsColumnMajor(instr, gemm_backend_config, "d")); | ||
TF_ASSIGN_OR_RETURN(bool a_is_col_major, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is this used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
MatrixIsColumnMajor(instr, gemm_backend_config, "a")); | ||
TF_ASSIGN_OR_RETURN(bool b_is_col_major, | ||
MatrixIsColumnMajor(instr, gemm_backend_config, "b")); | ||
// Bitcast the operands so that the layout is monotonic and dim 0 is major. | ||
std::vector<int64_t> a_dim_order; | ||
a_dim_order.reserve(a_dims.size()); | ||
absl::Span<const int64_t> a_minor_to_major = | ||
a->shape().layout().minor_to_major(); | ||
for (int i = 0; i < a_dims.size(); ++i) { | ||
a_dim_order.emplace_back( | ||
absl::c_find(a_minor_to_major, | ||
is_col_major ? i : a_dims.size() - i - 1) - | ||
a_minor_to_major.begin()); | ||
} | ||
a = instr->AddInstruction(HloInstruction::CreateTranspose( | ||
ShapeUtil::PermuteDimensions(a_dim_order, a->shape()), a, a_dim_order)); | ||
|
||
std::vector<int64_t> b_dim_order; | ||
b_dim_order.reserve(b_dims.size()); | ||
absl::Span<const int64_t> b_minor_to_major = | ||
b->shape().layout().minor_to_major(); | ||
for (int i = 0; i < b_dims.size(); ++i) { | ||
b_dim_order.emplace_back( | ||
absl::c_find(b_minor_to_major, | ||
is_col_major ? i : b_dims.size() - i - 1) - | ||
b_minor_to_major.begin()); | ||
} | ||
b = instr->AddInstruction(HloInstruction::CreateTranspose( | ||
ShapeUtil::PermuteDimensions(b_dim_order, b->shape()), b, b_dim_order)); | ||
|
||
// Identify the dimensional order which describes a transpose of the | ||
// contracting and non-contracting dimensions of the GEMM. | ||
|
@@ -681,20 +707,43 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { | |
return transp_dims; | ||
}; | ||
|
||
if (is_col_major) { | ||
std::vector<int64_t> new_dim_order = | ||
transp_dim_order(a, a_contracting_dims[0], a_batch_dims); | ||
a = instr->AddInstruction(HloInstruction::CreateTranspose( | ||
ShapeUtil::MakeShape(a->shape().element_type(), | ||
transp_dims(a, new_dim_order)), | ||
a, new_dim_order)); | ||
} else { | ||
std::vector<int64_t> new_dim_order = | ||
transp_dim_order(b, b_contracting_dims[0], b_batch_dims); | ||
b = instr->AddInstruction(HloInstruction::CreateTranspose( | ||
ShapeUtil::MakeShape(b->shape().element_type(), | ||
transp_dims(b, new_dim_order)), | ||
b, new_dim_order)); | ||
// Plain transpose on a or b. A plain transpose permutes a matrix's | ||
// dimensions without changing its storage order. | ||
auto plain_transpose = | ||
[&](HloInstruction **x, | ||
const absl::Span<const int64_t> &contracting_dims, | ||
const absl::Span<const int64_t> &batch_dims) { | ||
std::vector<int64_t> new_dim_order = | ||
transp_dim_order(*x, contracting_dims[0], batch_dims); | ||
*x = instr->AddInstruction(HloInstruction::CreateTranspose( | ||
ShapeUtil::MakeShapeWithDenseLayout( | ||
(*x)->shape().element_type(), transp_dims(*x, new_dim_order), | ||
(*x)->shape().layout().minor_to_major()), | ||
*x, new_dim_order)); | ||
}; | ||
|
||
DotDimensionNumbers *dim_nums = | ||
gemm_backend_config.mutable_dot_dimension_numbers(); | ||
int a_batch_dim_offset = a_batch_dims.size(); | ||
int b_batch_dim_offset = b_batch_dims.size(); | ||
// Apply the transposes necessary to canonicalize the matmul (lhs and rhs | ||
// contracting dims are 1 and 0). Also assumes a later transpose folding | ||
// pass will remove duplicated transposes. The last transpose is required | ||
// by the cuBLAS fp8 matmul restriction. | ||
if (a_contracting_dims[0] == a_batch_dim_offset) { | ||
dim_nums->set_lhs_contracting_dimensions(0, a_batch_dim_offset + 1); | ||
} | ||
if (b_contracting_dims[0] == 1 + b_batch_dim_offset) { | ||
dim_nums->set_rhs_contracting_dimensions(0, b_batch_dim_offset); | ||
} | ||
|
||
if (a_is_col_major && b_is_col_major) { | ||
plain_transpose(&a, a_contracting_dims, a_batch_dims); | ||
} else if (!a_is_col_major && !b_is_col_major) { | ||
plain_transpose(&b, b_contracting_dims, b_batch_dims); | ||
} else if (a_is_col_major && !b_is_col_major) { | ||
plain_transpose(&a, a_contracting_dims, a_batch_dims); | ||
plain_transpose(&b, b_contracting_dims, b_batch_dims); | ||
} | ||
|
||
std::unique_ptr<HloInstruction> new_custom_call = | ||
|
@@ -1195,9 +1244,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { | |
output_dtype)); | ||
} | ||
|
||
StatusOr<bool> OutputIsColumnMajor( | ||
const HloInstruction *instr, | ||
const GemmBackendConfig &gemm_backend_config) const { | ||
StatusOr<bool> MatrixIsColumnMajor( | ||
const HloInstruction *instr, const GemmBackendConfig &gemm_backend_config, | ||
const std::string matrix_name = "output") const { | ||
const HloInstruction *lhs = instr->operand(0); | ||
const HloInstruction *rhs = instr->operand(1); | ||
|
||
|
@@ -1214,7 +1263,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { | |
gemm_backend_config.alpha_imag(), gemm_backend_config.beta(), | ||
/*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision)); | ||
|
||
return gemm_config.output_layout.order == MatrixLayout::Order::kColumnMajor; | ||
if (matrix_name == "lhs" || matrix_name == "a") { | ||
return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor; | ||
} else if (matrix_name == "rhs" || matrix_name == "b") { | ||
return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor; | ||
} else if (matrix_name == "output" || matrix_name == "d") { | ||
return gemm_config.output_layout.order == | ||
MatrixLayout::Order::kColumnMajor; | ||
} else { | ||
return InternalError("Invalid matrix name."); | ||
} | ||
} | ||
|
||
StatusOr<bool> GemmIsSupportedByCublasLt( | ||
|
@@ -1270,7 +1328,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { | |
gemm_backend_config.dot_dimension_numbers(); | ||
|
||
TF_ASSIGN_OR_RETURN(bool output_is_column_major, | ||
OutputIsColumnMajor(instr, gemm_backend_config)); | ||
MatrixIsColumnMajor(instr, gemm_backend_config)); | ||
if (!output_is_column_major) { | ||
// cublasLt's matmul output is column major by default. This gemm requires | ||
// the output to be in row major. Later we will swap lhs & rhs (and | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4466,6 +4466,76 @@ TEST_F(CublasLtF8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { | |
)"); | ||
} | ||
|
||
TEST_F(CublasLtF8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { | ||
if (!GetCudaComputeCapability().IsAtLeast( | ||
se::CudaComputeCapability::HOPPER)) { | ||
GTEST_SKIP() | ||
<< "cuBLASLt FP8 kernels require Hopper or newer architecture."; | ||
} | ||
|
||
std::array<std::array<absl::string_view, 7>, 32> combinations; | ||
int i = 0; | ||
for (bool d_is_col : {false, true}) { | ||
for (bool a_is_col : {false, true}) { | ||
for (bool b_is_col : {true, true}) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be {false, true}? |
||
for (int lhs_contracting_dim : {0, 1}) { | ||
for (int rhs_contracting_dim : {0, 1}) { | ||
const absl::string_view lcd = | ||
lhs_contracting_dim == 1 ? "{1}" : "{0}"; | ||
const absl::string_view rcd = | ||
rhs_contracting_dim == 1 ? "{1}" : "{0}"; | ||
const absl::string_view a_logic = | ||
lhs_contracting_dim == 1 ? "[64,32]" : "[32,64]"; | ||
const absl::string_view b_logic = | ||
rhs_contracting_dim == 0 ? "[32,16]" : "[16,32]"; | ||
const absl::string_view a_physical = a_is_col ? "{0,1}" : "{1,0}"; | ||
const absl::string_view b_physical = b_is_col ? "{0,1}" : "{1,0}"; | ||
const absl::string_view output_physical = | ||
d_is_col ? "{0,1}" : "{1,0}"; | ||
combinations[i++] = | ||
std::array{lcd, rcd, a_logic, b_logic, a_physical, b_physical, | ||
output_physical}; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
for (const auto& combination : combinations) { | ||
absl::flat_hash_map<absl::string_view, absl::string_view> replacements; | ||
replacements["<<Lcd>>"] = std::get<0>(combination); | ||
replacements["<<Rcd>>"] = std::get<1>(combination); | ||
replacements["<<ALog>>"] = std::get<2>(combination); | ||
replacements["<<BLog>>"] = std::get<3>(combination); | ||
replacements["<<APhy>>"] = std::get<4>(combination); | ||
replacements["<<BPhy>>"] = std::get<5>(combination); | ||
replacements["<<OPhy>>"] = std::get<6>(combination); | ||
const char* hlo_template = R"( | ||
HloModule test | ||
|
||
ENTRY test { | ||
x = f8e4m3fn<<ALog>><<APhy>> parameter(0) | ||
x_f32 = f32<<ALog>><<APhy>> convert(x) | ||
x_scale = f32[] parameter(2) | ||
x_scale_bcast = f32<<ALog>> broadcast(x_scale), dimensions={} | ||
x_unscaled = f32<<ALog>> multiply(x_f32, x_scale_bcast) | ||
y = f8e4m3fn<<BLog>><<BPhy>> parameter(1) | ||
y_f32 = f32<<BLog>><<BPhy>> convert(y) | ||
y_scale = f32[] parameter(3) | ||
y_scale_bcast = f32<<BLog>> broadcast(y_scale), dimensions={} | ||
y_unscaled = f32<<BLog>> multiply(y_f32, y_scale_bcast) | ||
ROOT out = f32[64,16]<<OPhy>> dot(x_unscaled, y_unscaled), lhs_contracting_dims=<<Lcd>>, rhs_contracting_dims=<<Rcd>> | ||
} | ||
)"; | ||
const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements); | ||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 0.})); | ||
|
||
MatchOptimizedHlo(hlo_text, | ||
R"( | ||
; CHECK: custom_call_target="__cublas$lt$matmul$f8", | ||
)"); | ||
} | ||
}; | ||
|
||
class GemmRewriteAllocationTest : public GpuCodegenTest { | ||
public: | ||
void CheckNumberOfAllocations(const std::string& hlo, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MatrixIsColumnMajor
doesn't exist. Did you forget to commit that change?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Just added.