Skip to content

Commit

Permalink
Merge pull request #59515 from wenscarl:cublaslt_fp8_matmul_war
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 5125993
  • Loading branch information
tensorflower-gardener committed Feb 27, 2023
2 parents 711c031 + de695a6 commit 0e7e616
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 140 deletions.
10 changes: 6 additions & 4 deletions tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
Expand Up @@ -717,14 +717,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
// Shift any bitcasts to the unconverted and unscaled operands.
if (a_bitcast) {
a = instr->AddInstruction(a_bitcast->CloneWithNewOperands(
ShapeUtil::MakeShape(a->shape().element_type(),
a_bitcast->shape().dimensions()),
ShapeUtil::MakeShapeWithDenseLayout(
a->shape().element_type(), a_bitcast->shape().dimensions(),
a_bitcast->shape().layout().minor_to_major()),
{a}));
}
if (b_bitcast) {
b = instr->AddInstruction(b_bitcast->CloneWithNewOperands(
ShapeUtil::MakeShape(b->shape().element_type(),
b_bitcast->shape().dimensions()),
ShapeUtil::MakeShapeWithDenseLayout(
b->shape().element_type(), b_bitcast->shape().dimensions(),
b_bitcast->shape().layout().minor_to_major()),
{b}));
}

Expand Down
141 changes: 5 additions & 136 deletions tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
Expand Up @@ -4058,31 +4058,8 @@ TEST_F(CublasLtF8GemmRewriteTest, UnscaledABUnscaledDF8) {

MatchOptimizedHlo(hlo_text,
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16]) -> f8e4m3fn[16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
; CHECK-NEXT: [[FUSION:%[^ ]+]] = f8e4m3fn[16,16]{1,0} fusion()
; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[FUSION]], [[C1]], [[C1]], /*index=5*/[[C1]], [[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{
; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"]
; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"]
; CHECK-DAG: \"lhs_batch_dimensions\":[]
; CHECK-DAG: \"rhs_batch_dimensions\":[]
; CHECK-DAG: }
; CHECK-DAG: \"precision_config\":{
; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]
; CHECK-DAG: }
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK: }"
)");
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
)");
}

TEST_F(CublasLtF8GemmRewriteTest, ScaledABUnscaledDF8) {
Expand Down Expand Up @@ -4113,34 +4090,8 @@ TEST_F(CublasLtF8GemmRewriteTest, ScaledABUnscaledDF8) {
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 0.}));
MatchOptimizedHlo(hlo_text,
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[16,16]{1,0} broadcast([[C0]]), dimensions={}
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0_BCAST]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{
; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"]
; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"]
; CHECK-DAG: \"lhs_batch_dimensions\":[]
; CHECK-DAG: \"rhs_batch_dimensions\":[]
; CHECK-DAG: }
; CHECK-DAG: \"precision_config\":{
; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]
; CHECK-DAG: }
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK: }"
)");
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
)");
}

TEST_F(CublasLtF8GemmRewriteTest, BitcastScaledABUnscaledDF8) {
Expand Down Expand Up @@ -4209,33 +4160,7 @@ TEST_F(CublasLtF8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 0.}));
MatchOptimizedHlo(hlo_text,
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[10,16,32], y: f8e4m3fn[10,32,16], x_scale: f32[], y_scale: f32[]) -> f32[10,16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[10,16,32]{2,1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[10,32,16]{2,1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0)
; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[10,16,16]{2,1,0} broadcast([[C0]]), dimensions={}
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[10,16,16]{2,1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0_BCAST]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{
; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"]
; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"]
; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"]
; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"]
; CHECK-DAG: }
; CHECK-DAG: \"precision_config\":{
; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]
; CHECK-DAG: }
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK: }"
)");
}

Expand Down Expand Up @@ -4430,35 +4355,7 @@ TEST_F(CublasLtF8GemmRewriteTest, ScaledABScaledDF8) {
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 0.}));
MatchOptimizedHlo(hlo_text,
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[], z_scale: f32[]) -> f8e4m3fn[16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
; CHECK-NEXT: [[C0:%[^ ]+]] = bf16[] constant(0)
; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = bf16[16,16]{1,0} broadcast([[C0]]), dimensions={}
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C1]], [[P4]])
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0_BCAST]], [[P2]], [[P3]], /*index=5*/[[C1]], [[P4_INV]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{
; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"]
; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"]
; CHECK-DAG: \"lhs_batch_dimensions\":[]
; CHECK-DAG: \"rhs_batch_dimensions\":[]
; CHECK-DAG: }
; CHECK-DAG: \"precision_config\":{
; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]
; CHECK-DAG: }
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK: }"
)");
}

Expand Down Expand Up @@ -4626,35 +4523,7 @@ TEST_F(CublasLtF8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 0.}));
MatchOptimizedHlo(hlo_text,
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[], z_scale: f32[]) -> (f8e4m3fn[16,16], f32[]) {
; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]])
; CHECK-NEXT: [[C0:%[^ ]+]] = bf16[] constant(0)
; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = bf16[16,16]{1,0} broadcast([[C0]]), dimensions={}
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C1]], [[P4]])
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0_BCAST]], [[P2]], [[P3]], /*index=5*/[[C1]], [[P4_INV]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{
; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"]
; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"]
; CHECK-DAG: \"lhs_batch_dimensions\":[]
; CHECK-DAG: \"rhs_batch_dimensions\":[]
; CHECK-DAG: }
; CHECK-DAG: \"precision_config\":{
; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]
; CHECK-DAG: }
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK: }"
)");
}

Expand Down Expand Up @@ -4745,7 +4614,7 @@ TEST_F(CublasLtF8GemmRewriteTest, ScaledABUnscaledDF8ParameterizedBatched) {
}) +
"}";
};
std::array<std::array<std::string, 7>, 384> combinations;
std::array<std::array<std::string, 7>, 96> combinations;
std::string lcd, rcd, a_shape, b_shape, a_layout, b_layout, o_layout;
int i = 0;
for (bool o_is_col : {false, true}) {
Expand Down

0 comments on commit 0e7e616

Please sign in to comment.