[XLA] Cublaslt fp8 matmul restriction work-around #59515

Merged
merged 17 commits into from Feb 27, 2023
110 changes: 84 additions & 26 deletions tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
@@ -605,14 +605,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
return false;
}
if ((a_bitcast ? a_bitcast : a)->shape().dimensions_size() -
- gemm_backend_config.dot_dimension_numbers()
- .lhs_batch_dimensions()
- .size() !=
a_batch_dims.size() !=
2 ||
(b_bitcast ? b_bitcast : b)->shape().dimensions_size() -
- gemm_backend_config.dot_dimension_numbers()
- .rhs_batch_dimensions()
- .size() !=
b_batch_dims.size() !=
2) {
return false;
}
@@ -646,7 +642,37 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
// and B are later exchanged, and B is transposed here instead.
// TODO(philipphack): Remove once cuBLASLt supports the NN configuration.
TF_ASSIGN_OR_RETURN(bool is_col_major,
- OutputIsColumnMajor(instr, gemm_backend_config));
MatrixIsColumnMajor(instr, gemm_backend_config, "d"));

Member: MatrixIsColumnMajor doesn't exist. Did you forget to commit that change?

Contributor Author: Yes. Just added.

TF_ASSIGN_OR_RETURN(bool a_is_col_major,

Contributor: Where is this used?

Contributor Author: Fixed.

MatrixIsColumnMajor(instr, gemm_backend_config, "a"));
TF_ASSIGN_OR_RETURN(bool b_is_col_major,
MatrixIsColumnMajor(instr, gemm_backend_config, "b"));
// Bitcast the operands so that the layout is monotonic and dim 0 is major.
std::vector<int64_t> a_dim_order;
a_dim_order.reserve(a_dims.size());
absl::Span<const int64_t> a_minor_to_major =
a->shape().layout().minor_to_major();
for (int i = 0; i < a_dims.size(); ++i) {
a_dim_order.emplace_back(
absl::c_find(a_minor_to_major,
is_col_major ? i : a_dims.size() - i - 1) -
a_minor_to_major.begin());
}
a = instr->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::PermuteDimensions(a_dim_order, a->shape()), a, a_dim_order));

std::vector<int64_t> b_dim_order;
b_dim_order.reserve(b_dims.size());
absl::Span<const int64_t> b_minor_to_major =
b->shape().layout().minor_to_major();
for (int i = 0; i < b_dims.size(); ++i) {
b_dim_order.emplace_back(
absl::c_find(b_minor_to_major,
is_col_major ? i : b_dims.size() - i - 1) -
b_minor_to_major.begin());
}
b = instr->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::PermuteDimensions(b_dim_order, b->shape()), b, b_dim_order));

// Identify the dimensional order which describes a transpose of the
// contracting and non-contracting dimensions of the GEMM.
@@ -681,20 +707,43 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
return transp_dims;
};

- if (is_col_major) {
- std::vector<int64_t> new_dim_order =
- transp_dim_order(a, a_contracting_dims[0], a_batch_dims);
- a = instr->AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(a->shape().element_type(),
- transp_dims(a, new_dim_order)),
- a, new_dim_order));
- } else {
- std::vector<int64_t> new_dim_order =
- transp_dim_order(b, b_contracting_dims[0], b_batch_dims);
- b = instr->AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(b->shape().element_type(),
- transp_dims(b, new_dim_order)),
- b, new_dim_order));
// Plain transpose on a or b. A plain transpose permutes the matrix's
// dimensions without changing its storage order.
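// For illustration: with a 2-D [16,32] operand, the transpose created below
// produces a [32,16] shape whose layout reuses the operand's existing
// minor_to_major, i.e. the storage order is unchanged.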
auto plain_transpose =
[&](HloInstruction **x,
const absl::Span<const int64_t> &contracting_dims,
const absl::Span<const int64_t> &batch_dims) {
std::vector<int64_t> new_dim_order =
transp_dim_order(*x, contracting_dims[0], batch_dims);
*x = instr->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShapeWithDenseLayout(
(*x)->shape().element_type(), transp_dims(*x, new_dim_order),
(*x)->shape().layout().minor_to_major()),
*x, new_dim_order));
};

DotDimensionNumbers *dim_nums =
gemm_backend_config.mutable_dot_dimension_numbers();
int a_batch_dim_offset = a_batch_dims.size();
int b_batch_dim_offset = b_batch_dims.size();
// Apply the necessary transposes to canonicalize the matmul (lhs and rhs
// contracting dims are 1 and 0, respectively). Also assume a later
// transpose-folding pass will remove duplicated transposes. The last
// transpose is required by the cuBLAS FP8 matmul restriction.
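// For illustration: with no batch dimensions, a canonical [64,32] x [32,16]
// dot has lhs_contracting_dims={1} and rhs_contracting_dims={0}; operands
// whose contracting dimension does not match this form have their dimension
// numbers adjusted below.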
if (a_contracting_dims[0] == a_batch_dim_offset) {
dim_nums->set_lhs_contracting_dimensions(0, a_batch_dim_offset + 1);
}
if (b_contracting_dims[0] == 1 + b_batch_dim_offset) {
dim_nums->set_rhs_contracting_dimensions(0, b_batch_dim_offset);
}

if (a_is_col_major && b_is_col_major) {
plain_transpose(&a, a_contracting_dims, a_batch_dims);
} else if (!a_is_col_major && !b_is_col_major) {
plain_transpose(&b, b_contracting_dims, b_batch_dims);
} else if (a_is_col_major && !b_is_col_major) {
plain_transpose(&a, a_contracting_dims, a_batch_dims);
plain_transpose(&b, b_contracting_dims, b_batch_dims);
}

std::unique_ptr<HloInstruction> new_custom_call =
@@ -1195,9 +1244,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
output_dtype));
}

- StatusOr<bool> OutputIsColumnMajor(
- const HloInstruction *instr,
- const GemmBackendConfig &gemm_backend_config) const {
StatusOr<bool> MatrixIsColumnMajor(
const HloInstruction *instr, const GemmBackendConfig &gemm_backend_config,
const std::string matrix_name = "output") const {
const HloInstruction *lhs = instr->operand(0);
const HloInstruction *rhs = instr->operand(1);

@@ -1214,7 +1263,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
gemm_backend_config.alpha_imag(), gemm_backend_config.beta(),
/*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision));

- return gemm_config.output_layout.order == MatrixLayout::Order::kColumnMajor;
if (matrix_name == "lhs" || matrix_name == "a") {
return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor;
} else if (matrix_name == "rhs" || matrix_name == "b") {
return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor;
} else if (matrix_name == "output" || matrix_name == "d") {
return gemm_config.output_layout.order ==
MatrixLayout::Order::kColumnMajor;
} else {
return InternalError("Invalid matrix name.");
}
}

StatusOr<bool> GemmIsSupportedByCublasLt(
@@ -1270,7 +1328,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
gemm_backend_config.dot_dimension_numbers();

TF_ASSIGN_OR_RETURN(bool output_is_column_major,
- OutputIsColumnMajor(instr, gemm_backend_config));
MatrixIsColumnMajor(instr, gemm_backend_config));
if (!output_is_column_major) {
// cublasLt's matmul output is column major by default. This gemm requires
// the output to be in row major. Later we will swap lhs & rhs (and
70 changes: 70 additions & 0 deletions tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -4466,6 +4466,76 @@ TEST_F(CublasLtF8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
)");
}

TEST_F(CublasLtF8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
if (!GetCudaComputeCapability().IsAtLeast(
se::CudaComputeCapability::HOPPER)) {
GTEST_SKIP()
<< "cuBLASLt FP8 kernels require Hopper or newer architecture.";
}

std::array<std::array<absl::string_view, 7>, 32> combinations;
int i = 0;
for (bool d_is_col : {false, true}) {
for (bool a_is_col : {false, true}) {
for (bool b_is_col : {true, true}) {

Contributor: Should this be {false, true}?

for (int lhs_contracting_dim : {0, 1}) {
for (int rhs_contracting_dim : {0, 1}) {
const absl::string_view lcd =
lhs_contracting_dim == 1 ? "{1}" : "{0}";
const absl::string_view rcd =
rhs_contracting_dim == 1 ? "{1}" : "{0}";
const absl::string_view a_logic =
lhs_contracting_dim == 1 ? "[64,32]" : "[32,64]";
const absl::string_view b_logic =
rhs_contracting_dim == 0 ? "[32,16]" : "[16,32]";
const absl::string_view a_physical = a_is_col ? "{0,1}" : "{1,0}";
const absl::string_view b_physical = b_is_col ? "{0,1}" : "{1,0}";
const absl::string_view output_physical =
d_is_col ? "{0,1}" : "{1,0}";
combinations[i++] =
std::array{lcd, rcd, a_logic, b_logic, a_physical, b_physical,
output_physical};
}
}
}
}
}
for (const auto& combination : combinations) {
absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
replacements["<<Lcd>>"] = std::get<0>(combination);
replacements["<<Rcd>>"] = std::get<1>(combination);
replacements["<<ALog>>"] = std::get<2>(combination);
replacements["<<BLog>>"] = std::get<3>(combination);
replacements["<<APhy>>"] = std::get<4>(combination);
replacements["<<BPhy>>"] = std::get<5>(combination);
replacements["<<OPhy>>"] = std::get<6>(combination);
const char* hlo_template = R"(
HloModule test

ENTRY test {
x = f8e4m3fn<<ALog>><<APhy>> parameter(0)
x_f32 = f32<<ALog>><<APhy>> convert(x)
x_scale = f32[] parameter(2)
x_scale_bcast = f32<<ALog>> broadcast(x_scale), dimensions={}
x_unscaled = f32<<ALog>> multiply(x_f32, x_scale_bcast)
y = f8e4m3fn<<BLog>><<BPhy>> parameter(1)
y_f32 = f32<<BLog>><<BPhy>> convert(y)
y_scale = f32[] parameter(3)
y_scale_bcast = f32<<BLog>> broadcast(y_scale), dimensions={}
y_unscaled = f32<<BLog>> multiply(y_f32, y_scale_bcast)
ROOT out = f32[64,16]<<OPhy>> dot(x_unscaled, y_unscaled), lhs_contracting_dims=<<Lcd>>, rhs_contracting_dims=<<Rcd>>
}
)";
const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 0.}));

MatchOptimizedHlo(hlo_text,
R"(
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
)");
}
};

class GemmRewriteAllocationTest : public GpuCodegenTest {
public:
void CheckNumberOfAllocations(const std::string& hlo,