Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Cublaslt fp8 matmul restriction work-around #59515

Merged
merged 17 commits into from Feb 27, 2023
Merged
113 changes: 75 additions & 38 deletions tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
Expand Up @@ -523,7 +523,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
se::CudaComputeCapability::HOPPER)) {
return false;
}

#if CUDA_VERSION < 11080
// FP8 GEMM kernels are only available with CUDA 11.8 and above
return false;
Expand Down Expand Up @@ -625,14 +624,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
return false;
}
if ((a_bitcast ? a_bitcast : a)->shape().dimensions_size() -
gemm_backend_config.dot_dimension_numbers()
.lhs_batch_dimensions()
.size() !=
a_batch_dims.size() !=
2 ||
(b_bitcast ? b_bitcast : b)->shape().dimensions_size() -
gemm_backend_config.dot_dimension_numbers()
.rhs_batch_dimensions()
.size() !=
b_batch_dims.size() !=
2) {
return false;
}
Expand All @@ -650,24 +645,19 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
// Shift any bitcasts to the unconverted and unscaled operands.
if (a_bitcast) {
a = instr->AddInstruction(a_bitcast->CloneWithNewOperands(
ShapeUtil::MakeShape(a->shape().element_type(),
a_bitcast->shape().dimensions()),
ShapeUtil::MakeShapeWithDenseLayout(
a->shape().element_type(), a_bitcast->shape().dimensions(),
a_bitcast->shape().layout().minor_to_major()),
{a}));
}
if (b_bitcast) {
b = instr->AddInstruction(b_bitcast->CloneWithNewOperands(
ShapeUtil::MakeShape(b->shape().element_type(),
b_bitcast->shape().dimensions()),
ShapeUtil::MakeShapeWithDenseLayout(
b->shape().element_type(), b_bitcast->shape().dimensions(),
b_bitcast->shape().layout().minor_to_major()),
{b}));
}

// cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to
// be transposed. If the result of the GEMM is not in column major order, A
// and B are later exchanged, and B is transposed here instead.
// TODO(philipphack): Remove once cuBLASLt supports the NN configuration.
TF_ASSIGN_OR_RETURN(bool is_col_major,
OutputIsColumnMajor(instr, gemm_backend_config));

// Identify the dimensional order which describes a transpose of the
// contracting and non-contracting dimensions of the GEMM.
auto transp_dim_order =
Expand Down Expand Up @@ -700,23 +690,62 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
}
return transp_dims;
};
// Plain transpose on a or b. A plain transpose permutes the matrix's
// dimensions without changing its storage order.
auto plain_transpose =
[&](HloInstruction **x,
const absl::Span<const int64_t> &contracting_dims,
const absl::Span<const int64_t> &batch_dims) {
std::vector<int64_t> new_dim_order =
transp_dim_order(*x, contracting_dims[0], batch_dims);
*x = instr->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShapeWithDenseLayout(
(*x)->shape().element_type(), transp_dims(*x, new_dim_order),
(*x)->shape().layout().minor_to_major()),
*x, new_dim_order));
};

if (is_col_major) {
std::vector<int64_t> new_dim_order =
transp_dim_order(a, a_contracting_dims[0], a_batch_dims);
a = instr->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(a->shape().element_type(),
transp_dims(a, new_dim_order)),
a, new_dim_order));
} else {
std::vector<int64_t> new_dim_order =
transp_dim_order(b, b_contracting_dims[0], b_batch_dims);
b = instr->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(b->shape().element_type(),
transp_dims(b, new_dim_order)),
b, new_dim_order));
// cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to
// be transposed. If the result of the GEMM is not in column major order, A
// and B are later exchanged, and B is transposed here instead.
// TODO(philipphack): Remove once cuBLASLt supports the NN configuration.
TF_ASSIGN_OR_RETURN(bool a_is_col_major,
MatrixIsColumnMajor(instr, gemm_backend_config, "a"));
TF_ASSIGN_OR_RETURN(bool b_is_col_major,
MatrixIsColumnMajor(instr, gemm_backend_config, "b"));

// Apply the transposes necessary to canonicalize the matmul (lhs and rhs
// contracting dims are 1 and 0, respectively). We assume that a later
// transpose folding pass will remove duplicated transposes. The last
// transpose is required by the cuBLASLt FP8 matmul restriction.
DotDimensionNumbers *dim_nums =
gemm_backend_config.mutable_dot_dimension_numbers();
int a_batch_dim_offset = a_batch_dims.size();
int b_batch_dim_offset = b_batch_dims.size();

if (a_is_col_major) {
// Swap contracting dimensions and convert a to row major
CHECK(a_contracting_dims[0] == a_batch_dim_offset ||
a_contracting_dims[0] == a_batch_dim_offset + 1);
if (a_contracting_dims[0] == a_batch_dim_offset) {
dim_nums->set_lhs_contracting_dimensions(0, a_batch_dim_offset + 1);
} else {
dim_nums->set_lhs_contracting_dimensions(0, a_batch_dim_offset);
}
plain_transpose(&a, a_contracting_dims, a_batch_dims);
}

if (!b_is_col_major) {
// Swap contracting dimensions and convert b to col major
CHECK(b_contracting_dims[0] == b_batch_dim_offset ||
b_contracting_dims[0] == b_batch_dim_offset + 1);
if (b_contracting_dims[0] == b_batch_dim_offset) {
dim_nums->set_rhs_contracting_dimensions(0, b_batch_dim_offset + 1);
} else {
dim_nums->set_rhs_contracting_dimensions(0, b_batch_dim_offset);
}
plain_transpose(&b, b_contracting_dims, b_batch_dims);
}
std::unique_ptr<HloInstruction> new_custom_call =
HloInstruction::CreateCustomCall(
instr->shape(),
Expand All @@ -728,7 +757,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call.get()));
TF_RETURN_IF_ERROR(
ReplaceWithNewInstruction(instr, std::move(new_custom_call)));

return true;
}

Expand Down Expand Up @@ -1217,9 +1245,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
output_dtype));
}

StatusOr<bool> OutputIsColumnMajor(
const HloInstruction *instr,
const GemmBackendConfig &gemm_backend_config) const {
StatusOr<bool> MatrixIsColumnMajor(
const HloInstruction *instr, const GemmBackendConfig &gemm_backend_config,
const std::string matrix_name = "output") const {
const HloInstruction *lhs = instr->operand(0);
const HloInstruction *rhs = instr->operand(1);

Expand All @@ -1236,7 +1264,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
gemm_backend_config.alpha_imag(), gemm_backend_config.beta(),
/*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision));

return gemm_config.output_layout.order == MatrixLayout::Order::kColumnMajor;
if (matrix_name == "lhs" || matrix_name == "a") {
return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor;
} else if (matrix_name == "rhs" || matrix_name == "b") {
return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor;
} else if (matrix_name == "output" || matrix_name == "d") {
return gemm_config.output_layout.order ==
MatrixLayout::Order::kColumnMajor;
} else {
return InternalError("Invalid matrix name.");
}
}

StatusOr<bool> GemmIsSupportedByCublasLt(
Expand Down Expand Up @@ -1292,7 +1329,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
gemm_backend_config.dot_dimension_numbers();

TF_ASSIGN_OR_RETURN(bool output_is_column_major,
OutputIsColumnMajor(instr, gemm_backend_config));
MatrixIsColumnMajor(instr, gemm_backend_config));
if (!output_is_column_major) {
// cublasLt's matmul output is column major by default. This gemm requires
// the output to be in row major. Later we will swap lhs & rhs (and
Expand Down
11 changes: 4 additions & 7 deletions tensorflow/compiler/xla/service/gpu/matmul_utils.cc
Expand Up @@ -806,13 +806,10 @@ StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
lhs_layout.batch_size = batch_size;
rhs_layout.batch_size = batch_size;

// cuBLASLt FP8 GEMM kernels require A (i.e. lhs) to be transposed.
se::blas::Transpose trans_a;
if (lhs_layout.dtype == F8E4M3FN || lhs_layout.dtype == F8E5M2) {
trans_a = se::blas::Transpose::kTranspose;
} else {
trans_a = se::blas::Transpose::kNoTranspose;
}
// cuBLASLt FP8 GEMM kernels require A (i.e. lhs) to be transposed, which is
// equivalent to A being stored in row-major order when no transpose is
// explicitly applied to A.
se::blas::Transpose trans_a = se::blas::Transpose::kNoTranspose;

bool must_swap_operands =
MakeOutputColumnMajor(lhs_layout, rhs_layout, c_layout, output_layout);
Expand Down