Skip to content

Commit

Permalink
[XLA:GPU] Forbid fusing concatenations in dots when the non-contracti…
Browse files Browse the repository at this point in the history
…ng dimension is split.

This is not handled correctly by codegen at this point.

PiperOrigin-RevId: 621241252
  • Loading branch information
bchetioui authored and tensorflower-gardener committed Apr 3, 2024
1 parent e0ad109 commit 3c2aa94
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
25 changes: 25 additions & 0 deletions third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2011,6 +2011,31 @@ ENTRY e {
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}));
}

TEST_F(TritonGemmTestAny,
DoNotFuseConcatenationOfSplitNonContractingDimension) {
const std::string hlo_text = R"(
HloModule m
ENTRY e {
x = bf16[2,128,10] parameter(0)
y = bf16[2,256,10] parameter(1)
concat = bf16[2,384,10] concatenate(x, y), dimensions={1}
z = bf16[10,20] parameter(2)
ROOT d = bf16[2,384,20] dot(concat, z), lhs_contracting_dims={2}, rhs_contracting_dims={0}
})";

MatchOptimizedHlo(hlo_text, R"(
; CHECK: ENTRY
; CHECK: concatenate
; CHECK: ROOT
; CHECK-SAME: fusion
; CHECK-SAME: kind=kCustom
; CHECK-SAME: "block_m"
)");

EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

class TritonGemmLevel2Test : public TritonGemmTest {
public:
DebugOptions GetDebugOptionsForTest() override {
Expand Down
20 changes: 17 additions & 3 deletions third_party/xla/xla/service/gpu/triton_tiling_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -943,9 +943,23 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo,
if (!std::holds_alternative<DotProperties>(properties)) {
return "Concatenations for now are only supported in GEMM fusions.";
}
auto dim = LogicalIndexOfLabeledDimension(
hlo.shape(), src_dim_order,
std::get<DotProperties>(properties).noncontracting_dimension);

int64_t noncontracting_dim_label =
std::get<DotProperties>(properties).noncontracting_dimension;
const FragmentOrders& src_dim_fragments_orders =
src_dim_order.DimFragmentsOrders();

auto noncontracting_dim_fragment_order =
src_dim_fragments_orders.find(noncontracting_dim_label);
if (noncontracting_dim_fragment_order != src_dim_fragments_orders.end()) {
if (noncontracting_dim_fragment_order->second.size() > 1) {
return "Concatenations on split non-contracting dimensions are "
"unsupported.";
}
}

auto dim = LogicalIndexOfLabeledDimension(hlo.shape(), src_dim_order,
noncontracting_dim_label);
if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) {
return "Unsupported concatenation.";
}
Expand Down

0 comments on commit 3c2aa94

Please sign in to comment.