Skip to content

Commit

Permalink
Fix broken GemmRewriteTest.BF16GemmCodeGen on Hopper.
Browse files Browse the repository at this point in the history
This was broken by openxla/xla@4e09e73.

The issue is that the optimized HLO was changed by having a native BF16 multiply, so the filecheck string had to be changed. Because the multiply is done in BF16 instead of FP32, the tolerance must also be lowered.

PiperOrigin-RevId: 616380852
  • Loading branch information
reedwm authored and tensorflower-gardener committed Mar 16, 2024
1 parent 3efd2d3 commit d57dfed
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,32 @@ ENTRY bf16gemm {
}
)";

MatchOptimizedHlo(hlo_text, R"(
; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
)");
if (CudaOrRocmCheck(9, 0, Switch::False)) {
// The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
// native BF16 multiply support.
MatchOptimizedHlo(hlo_text, R"(
; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
; CHECK: [[INSTR_2:%[^ ]+]] = bf16[3]{0} multiply([[P0]], [[P1]])
; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[INSTR_2]])
; CHECK: [[INSTR_4:%[^ ]+]] = f32[] constant(0)
; CHECK: [[INSTR_5:%[^ ]+]] = f32[] reduce([[INSTR_3]], [[INSTR_4]]), dimensions={0}, to_apply=[[INSTR_6:%[^ ]+]]
; CHECK: ROOT [[INSTR_7:%[^ ]+]] = bf16[] convert([[INSTR_5]])
)");
} else {
MatchOptimizedHlo(hlo_text, R"(
; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
)");
}

EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-4}));
}

TEST_F(GemmRewriteTest, BF16Transpose) {
Expand Down

0 comments on commit d57dfed

Please sign in to comment.